211 lines
5.8 KiB
Python
211 lines
5.8 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import gzip
|
||
import heapq
|
||
import click
|
||
import csv
|
||
import polars as pl
|
||
|
||
from multiprocessing import Pool
|
||
|
||
from glob import glob
|
||
from tqdm import tqdm
|
||
|
||
|
||
def amino_acid_to_codon():
    """Return the set of every DNA codon in the standard genetic code.

    The local table maps each single-letter amino-acid code (``*`` marks the
    stop codons) to its synonymous codons. This function takes no argument:
    it flattens the whole table and returns all 64 codons, stop codons
    included.

    Returns:
        set[str]: all three-letter codons appearing in the table.
    """
    genetic_code = {
        'A': ['GCT', 'GCC', 'GCA', 'GCG'],
        'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'],
        'N': ['AAT', 'AAC'],
        'D': ['GAT', 'GAC'],
        'C': ['TGT', 'TGC'],
        'E': ['GAA', 'GAG'],
        'Q': ['CAA', 'CAG'],
        'G': ['GGT', 'GGC', 'GGA', 'GGG'],
        'H': ['CAT', 'CAC'],
        'I': ['ATT', 'ATC', 'ATA'],
        'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'],
        'K': ['AAA', 'AAG'],
        'M': ['ATG'],
        'F': ['TTT', 'TTC'],
        'P': ['CCT', 'CCC', 'CCA', 'CCG'],
        'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'],
        'T': ['ACT', 'ACC', 'ACA', 'ACG'],
        'W': ['TGG'],
        'Y': ['TAT', 'TAC'],
        'V': ['GTT', 'GTC', 'GTA', 'GTG'],
        '*': ['TAA', 'TAG', 'TGA'],
    }

    return {codon for codons in genetic_code.values() for codon in codons}


# All codons of the standard genetic code, computed once at import time and
# used to check that the trailing token of a sequence name is a real codon.
__CODONS__ = amino_acid_to_codon()
|
||
|
||
|
||
def reader(path: str, rt_len: int = 24):
    """Stream a gzip-compressed CSV file, yielding rows with RTlength <= *rt_len*.

    Rows are produced one at a time by ``csv.DictReader``, so the whole file
    is never loaded into memory.

    Args:
        path: path to a gzip-compressed, UTF-8 CSV file with a header row.
        rt_len: inclusive upper bound on the ``RTlength`` column.

    Yields:
        dict: one row (header name -> cell text) per qualifying line.

    Rows whose ``RTlength`` is missing (short row), empty, or not a number
    are skipped. If the header has no ``RTlength`` column at all, ``KeyError``
    is raised when the first data row is read.
    """
    with gzip.open(path, "rt", newline="", encoding="utf-8") as handle:
        for row in csv.DictReader(handle):
            try:
                rt_value = float(row["RTlength"])
            except (TypeError, ValueError):
                # TypeError: short row, DictReader filled RTlength with None.
                # ValueError: empty or non-numeric cell such as "NA".
                continue
            if rt_value <= rt_len:
                yield row
|
||
|
||
|
||
def __check_target__(key: str):
    """Tell whether the final token of *key* is a codon from ``__CODONS__``.

    ``>`` is treated as an alternative separator to ``_``, so both
    ``GENE_AA1_CTC_CTT`` and ``GENE_AA1_CTC>CTT`` are judged on ``CTT``.
    A key with no separator is judged on the whole string.
    """
    normalised = key.replace(">", "_")
    _, _, last_token = normalised.rpartition("_")
    return last_token in __CODONS__
|
||
|
||
|
||
def __decode_codon_n__(key: str) -> str:
    """Collapse the last two codons of *key* into a single degenerate codon.

    Example: ``BRIP1_AA580_CTC_CTT`` -> ``BRIP1_AA580_CTC_CTN``.

    ``>`` is rewritten to ``_`` before splitting, so it also appears as ``_``
    in the result. Positions where the second-to-last and last tokens share
    a base keep that base; differing positions become ``N``. Pairing stops at
    the shorter of the two tokens, and only the last token is replaced.

    Raises:
        IndexError: when *key* has fewer than two tokens; the token list is
            printed before the exception propagates.
    """
    tokens = key.replace(">", "_").split("_")

    try:
        reference = tokens[-2]
    except IndexError:
        print(tokens)
        raise

    tokens[-1] = "".join(
        ref_base if ref_base == alt_base else "N"
        for ref_base, alt_base in zip(reference, tokens[-1])
    )
    return "_".join(tokens)
|
||
|
||
def _score_column(rec):
    """Return the name of the score column present in *rec*, or None.

    ``DeepCas9score`` wins over ``PRIDICT2_0_editing_Score_deep_K562`` when
    both columns exist.
    """
    if "DeepCas9score" in rec:
        return "DeepCas9score"
    if "PRIDICT2_0_editing_Score_deep_K562" in rec:
        return "PRIDICT2_0_editing_Score_deep_K562"
    return None


def _push_top_n(heap, entry, top_n):
    """Keep at most *top_n* of the highest-scoring entries in min-heap *heap*.

    *entry* is ``(score, seq, rec)``. ``seq`` is unique per record, so equal
    scores are ordered by it and the ``rec`` dicts are never compared (dicts
    are unorderable and would raise ``TypeError``).
    """
    if top_n <= 0:
        # Nothing to keep; also avoids peeking at heap[0] on an empty heap.
        return
    if len(heap) < top_n:
        heapq.heappush(heap, entry)
    elif entry[0] > heap[0][0]:
        # Only a strictly higher score evicts the current minimum.
        heapq.heapreplace(heap, entry)


def __call_func__(args):
    """Keep the top-N scoring records per sequence name in one file and write them.

    Args:
        args: ``(f, outdir, top_n, degenerate)`` packed into one sequence so
            this function can be mapped with ``Pool.imap``:

            - f: gzip CSV input path, read through ``reader``.
            - outdir: directory for the output file (same basename as *f*).
            - top_n: how many highest-scoring records to keep per name.
            - degenerate: when true, the last two codons of the sequence name
              are merged via ``__decode_codon_n__``; the merged name replaces
              ``sequence_name`` and the original goes to ``orig_seq_name``.

    Records are skipped when the trailing codon is not a known codon, when the
    name cannot be decoded, when no score column exists, or when the score is
    not a number. Output is grouped by name (first-seen order), highest score
    first; equal scores keep input order. Nothing is written if no record
    survives.
    """
    f, outdir, top_n, degenerate = args

    data = {}

    # First pass: stream the file and keep a bounded heap per sequence name.
    for seq, rec in enumerate(tqdm(reader(f))):
        key = rec["sequence_name"]

        if not __check_target__(key):
            # The target is not a known amino-acid-coding codon: skip it.
            continue

        if degenerate:
            try:
                key = __decode_codon_n__(rec["sequence_name"])
            except IndexError:
                continue
            rec["orig_seq_name"] = rec.pop("sequence_name")
            rec["sequence_name"] = key

        heap = data.setdefault(key, [])

        score_col = _score_column(rec)
        if score_col is None:
            print(f, rec)
            continue

        try:
            score = float(rec[score_col])
        except (ValueError, TypeError):
            # ValueError: non-numeric text; TypeError: short row (value None).
            print(f"Warning: Skipping invalid record in {f}: {rec}")
            continue

        _push_top_n(heap, (score, seq, rec), top_n)

    # Second pass: order each group by descending score, ties by input order.
    final_records = []
    for heap in data.values():
        ordered = sorted(heap, key=lambda entry: (-entry[0], entry[1]))
        final_records.extend(rec for _, _, rec in ordered)

    if not final_records:
        print(f"No valid records in {f}, skipping output.")
        return

    # Union of columns in first-seen order so DictWriter never rejects a
    # record that carries a key the first record lacks.
    fieldnames = list(dict.fromkeys(k for rec in final_records for k in rec))

    output_path = os.path.join(outdir, os.path.basename(f))
    with gzip.open(output_path, "wt", newline="", encoding="utf-8") as w:
        writer = csv.DictWriter(w, fieldnames=fieldnames, quoting=csv.QUOTE_MINIMAL)
        writer.writeheader()
        writer.writerows(final_records)
|
||
|
||
|
||
@click.command()
@click.option("-i", "--indir", type=str, help="字符串形式的输入路径,可以*通配多个文件和目录")
@click.option("-o", "--outdir", type=str, help="输出目录")
@click.option("-t", "--top-n", type=int, help="选择前几", default=3)
@click.option("-n", "--degenerate", is_flag=True, help="是否使用兼并碱基")
@click.argument('args', nargs=-1)  # capture all positional arguments
def main(indir, outdir, top_n, degenerate, args):
    """Keep the top-N scoring records per sequence name for every file matched by INDIR.

    The input glob and output directory may be given with -i/-o or
    positionally (first positional = input glob, last = output directory).
    """
    # Positional fallback when the options were not supplied.
    if not indir and args:
        indir = args[0]
    if not outdir and args:
        outdir = args[-1]

    if not indir or not outdir:
        raise click.UsageError("both an input path/glob and an output directory are required")

    if indir == outdir:
        raise ValueError("indir and outdir should not be the same")

    files = sorted(glob(indir))

    # Outputs are written as outdir/<basename of input>; an input that already
    # lives in outdir would be overwritten by its own filtered result.
    out_real = os.path.realpath(outdir)
    if any(os.path.realpath(os.path.dirname(f)) == out_real for f in files):
        raise ValueError("indir and outdir should not be the same")

    os.makedirs(outdir, exist_ok=True)

    if not files:
        click.echo(f"No input files matched {indir!r}", err=True)
        return

    # One task per input file, handled by the worker function.
    tasks = [[f, outdir, top_n, degenerate] for f in files]

    # Pool(0) is invalid, and one process per file can oversubscribe the host.
    processes = min(len(tasks), os.cpu_count() or 1)
    with Pool(processes) as pool:
        list(tqdm(pool.imap(__call_func__, tasks), total=len(tasks)))


if __name__ == '__main__':
    main()
|
||
|