#!/usr/bin/env python # -*- coding: utf-8 -*- import os import gzip import heapq import click import csv import polars as pl from multiprocessing import Pool from glob import glob from tqdm import tqdm def amino_acid_to_codon(): """ 简化的氨基酸到密码子转换函数 参数: amino_acid (str): 单字母氨基酸代码 返回: list: 可能的密码子列表 """ genetic_code = { 'A': ['GCT', 'GCC', 'GCA', 'GCG'], 'R': ['CGT', 'CGC', 'CGA', 'CGG', 'AGA', 'AGG'], 'N': ['AAT', 'AAC'], 'D': ['GAT', 'GAC'], 'C': ['TGT', 'TGC'], 'E': ['GAA', 'GAG'], 'Q': ['CAA', 'CAG'], 'G': ['GGT', 'GGC', 'GGA', 'GGG'], 'H': ['CAT', 'CAC'], 'I': ['ATT', 'ATC', 'ATA'], 'L': ['TTA', 'TTG', 'CTT', 'CTC', 'CTA', 'CTG'], 'K': ['AAA', 'AAG'], 'M': ['ATG'], 'F': ['TTT', 'TTC'], 'P': ['CCT', 'CCC', 'CCA', 'CCG'], 'S': ['TCT', 'TCC', 'TCA', 'TCG', 'AGT', 'AGC'], 'T': ['ACT', 'ACC', 'ACA', 'ACG'], 'W': ['TGG'], 'Y': ['TAT', 'TAC'], 'V': ['GTT', 'GTC', 'GTA', 'GTG'], '*': ['TAA', 'TAG', 'TGA'], } codes = [] for val in genetic_code.values(): codes += val return set(codes) # genetic_code.get(amino_acid.upper(), []) __CODONS__ = amino_acid_to_codon() def reader(path: str, rt_len: int = 24): """ 流式读取 CSV 文件,逐行返回 dict。 内存占用恒定(只缓存一行),适合 GB 级文件。 """ with gzip.open(path, "rt", newline="") as f: for row in csv.DictReader(f): try: if float(row["RTlength"]) <= rt_len: yield row except TypeError: continue def __check_target__(key: str): if ">" in key: key = key.replace(">", "_") keys = key.split("_") return keys[-1] in __CODONS__ def __decode_codon_n__(key: str) -> str: # BRIP1_AA580_CTC_CTT if ">" in key: key = key.replace(">", "_") keys = key.split("_") try: res = [] for x, y in zip(keys[-2], keys[-1]): if x == y: res.append(x) else: res.append("N") keys[-1] = "".join(res) except IndexError as err: print(keys) raise err return "_".join(keys) def __call_func__(args): u""" 实际处理代码 """ f, outdir, top_n, degenerate = args data = {} # 读取文件 for rec in tqdm(reader(f)): # 根据设定好的sequence名称 key = rec["sequence_name"] if not __check_target__(key): # 如果target不是已知的编码氨基酸的codon则跳过 continue if degenerate: try: key = __decode_codon_n__(rec["sequence_name"]) rec["orig_seq_name"] = rec.pop("sequence_name") rec["sequence_name"] = key except IndexError: continue if key not in data: data[key] = [] # 数据heap化 if "DeepCas9score" in rec.keys(): k = "DeepCas9score" elif "PRIDICT2_0_editing_Score_deep_K562" in rec.keys(): k = "PRIDICT2_0_editing_Score_deep_K562" else: print(f, rec) continue # raise ValueError(f"PRIDICT2_0_editing_Score_deep_K562 not exists in {f}") try: score = float(rec[k]) except (ValueError, KeyError) as e: print(f"Warning: Skipping invalid record in {f}: {rec}") continue # 或 raise,根据需求 if len(data[key]) < top_n: heapq.heappush(data[key], (score, rec)) else: try: if score > data[key][0][0]: heapq.heapreplace(data[key], (score, rec)) except TypeError as err: print(err) print(key) print(score) print(len(data[key])) raise err # 第二遍:整理结果(按 score 降序) final_records = [] for heap in data.values(): # 从堆中取出并按 score 降序排列 sorted_recs = [rec for _, rec in sorted(heap, key=lambda x: x[0], reverse=True)] final_records.extend(sorted_recs) if not final_records: print(f"No valid records in {f}, skipping output.") return # 安全写入 CSV(使用 csv 模块) output_path = os.path.join(outdir, os.path.basename(f)) with gzip.open(output_path, "wt+", newline="", encoding="utf-8") as w: writer = csv.DictWriter(w, fieldnames=final_records[0].keys(), quoting=csv.QUOTE_MINIMAL) writer.writeheader() writer.writerows(final_records) @click.command() @click.option("-i", "--indir", type=str, help="字符串形式的输入路径,可以*通配多个文件和目录") @click.option("-o", "--outdir", type=str, help="输出目录") @click.option("-t", "--top-n", type=int, help="选择前几", default=3) @click.option("-n", "--degenerate", is_flag=True, help="是否使用兼并碱基") @click.argument('args', nargs=-1) # 捕获所有位置参数 def main(indir, outdir, top_n, degenerate, args): if not indir and len(args) > 0: indir = args[0] if not outdir and len(args) > 0: outdir = args[-1] if indir == outdir: raise ValueError("indir and outdir should not be the same") os.makedirs(outdir, exist_ok=True) # 获取输入文件,生成参数 args = [[f, outdir, top_n, degenerate] for f in glob(indir)] # for arg in args: # print(arg[0]) # __call_func__(arg) with Pool(len(args)) as p: list(tqdm(p.imap(__call_func__, args), total=len(args))) if __name__ == '__main__': main()