PSP - 解决 ESMFold 推理长序列蛋白质结构的显存溢出问题

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/134709211

IMG

使用 ESMFold 推理长序列 (Seq. Len. > 1500) 时,导致显存不足,需要设置 chunk_size 参数,实现长序列蛋白质的结构预测,避免显存溢出。

ESMFold:https://github.com/facebookresearch/esm

测试 ESM 单条 Case,序列长度 1543 较长,即:

python -u myscripts/esmfold_infer.py \
-f fasta_446/7WY5_R1543.fasta \
-o mydata/test_gpcr/

A100 显存溢出:

Tried to allocate 54.74 GiB (GPU 0; 79.32 GiB total capacity; 73.53 GiB already allocated; 3.94 GiB free; 74.24 GiB reserved in total by PyTorch)

解决显存问题,参考:Out of memory - upper limit on sequence length?

关键参数:chunk-size

Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). Equivalent to running a for loop over chunks of of each dimension. Lower values will result in lower memory usage at the cost of speed. Recommended values: 128, 64, 32. Default: None.

将轴向注意力计算分块 (Chunks) ,将内存使用量从 O(L^2) 减少到 O(L)。 相当于在每个维度的块上运行 for 循环。 较低的值将导致内存使用量降低,但代价是速度。 建议值:128、64、32。默认值:无。

关键参数:max-tokens-per-batch,即 max_tokens_per_batch

Maximum number of tokens per gpu forward-pass. This will group shorter sequences together for batched prediction. Lowering this can help with out of memory issues, if these occur on short sequences.

每个 GPU 前向传递的最大令牌数。 这会将较短的序列分组在一起以进行批量预测。 如果内存不足问题发生在短序列上,降低此值可以帮助解决这些问题。

chunk-size 设置成 128,问题解决,即:

max_len = 1200
# A100 最多支持 1200 长度的序列
if len(seq) > max_len:
    chunk_size = 128
    print(f"[Warning] seq length is too long! {len(seq)} > {max_len}, chunk_size: {chunk_size}")
    self.model.set_chunk_size(chunk_size)
else:
    self.model.set_chunk_size(None)
    
with torch.no_grad():
    output = self.model.infer_pdb(seq)

推理脚本:

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/7/5
"""
import argparse
import os
import sys
import time
from pathlib import Path

import torch
from tqdm import tqdm

import esm

p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:
    sys.path.append(p)


from myutils.protein_utils import get_seq_from_fasta
from myutils.project_utils import time_elapsed, mkdir_if_not_exist, traverse_dir_files


class EsmfoldInfer(object):
    """
    ESMFold的推理类
    """
    def __init__(self):
        print("[Info] 开始加载 ESMFold 模型!")
        s_time = time.time()
        model = esm.pretrained.esmfold_v1()
        self.model = model.eval().cuda()
        print(f"[Info] vocab: {self.model.esm_dict.to_dict()}")
        # 耗时: 00:01:13.264272
        print(f"[Info] 完成加载 ESMFold 模型! 耗时: {time_elapsed(s_time, time.time())}")

    def predict_seq(self, seq, out_path, is_log=True):
        """
        预测序列
        """
        print(f"[Info] seq_len: {len(seq)}")
        max_len = 1200
        # A100 最多支持 1200 长度的序列
        if len(seq) > max_len:
            chunk_size = 128
            print(f"[Warning] seq length is too long! {len(seq)} > {max_len}, chunk_size: {chunk_size}")
            self.model.set_chunk_size(chunk_size)
        else:
    		self.model.set_chunk_size(None)

        s_time = time.time()
        with torch.no_grad():
            output = self.model.infer_pdb(seq)
        seq_len = len(seq)
        if is_log:
            print(f"[Info] 完成推理,链长 {seq_len}, 耗时: {time_elapsed(s_time, time.time())}, "
                  f"平均序列耗时: {(time.time() - s_time) / seq_len}")
        with open(out_path, "w") as f:
            f.write(output)
        if is_log:
            print(f"[Info] 输出: {output}")

    def predict_fasta_dir(self, input_path, output_dir):
        """
        预测 FASTA 文件夹
        """
        print(f"[Info] input_path: {input_path}")
        print(f"[Info] output_dir: {output_dir}")
        assert os.path.isfile(input_path) or os.path.isdir(input_path)
        mkdir_if_not_exist(output_dir)

        if os.path.isdir(input_path):
            path_list = traverse_dir_files(input_path, ext="fasta")
        elif os.path.isfile(input_path):
            path_list = [input_path]
        else:
            raise Exception(f"Error input: {input_path}")

        print(f"[Info] Fasta 数量: {len(path_list)}")
        s_time = time.time()
        for path in tqdm(path_list, desc="[Info] fasta"):
            fasta_name = os.path.basename(path).split(".")[0]
            output_fasta_dir = os.path.join(output_dir, fasta_name)
            mkdir_if_not_exist(output_fasta_dir)

            pdb_name = os.path.basename(path).replace("fasta", "pdb")
            output_pdb_path = os.path.join(output_fasta_dir, pdb_name)

            if os.path.exists(output_pdb_path):
                print(f"[Info] 已预测完成: {output_pdb_path}")
                continue
            seqs, _ = get_seq_from_fasta(path)
            seq = seqs[0]
            self.predict_seq(seq, output_pdb_path, is_log=False)
        print(f"[Info] 全部运行完成: {output_dir}, 耗时: {time_elapsed(s_time, time.time())}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f",
        "--fasta-input",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=Path,
        required=True
    )
    args = parser.parse_args()

    fasta_input = str(args.fasta_input)
    output_dir = str(args.output_dir)
    mkdir_if_not_exist(output_dir)

    ei = EsmfoldInfer()
    ei.predict_fasta_dir(fasta_input, output_dir)


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/203906.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue大屏自适应终极解决方案

v-scale-screenv-scale-screen是一个大屏自适应组件,在实际业务中,我们常用图表来做数据统计,数据展示,数据可视化等比较直观的方式来达到一目了然的数据查看,但在大屏开发过程中,常会因为适配不同屏幕而感…

前端三大MV*模式:MVC、mvvm、mvp模式介绍

MVC(同步通信为主):Model、View、Controller MVP(异步通信为主):Model、View、Presenter MVVM(异步通信为主):Model、View、ViewModel mvc模式介绍 MVC(Model–View–Controller)模式是软件…

C语言——编写程序,判断从键盘输入字符的类型(大写字母、小写字母、数字、其他四类)

#define _CRT_SECURE_NO_WARNINGS 1#include <ctype.h> #include <stdio.h> int main() { char c;printf("请输入一个字符: \n");scanf("%c",&c);if (isupper(c)) {printf("这是一个大写字母\n");} else if (islower(c)) {pr…

解决tomcat 启动 , 中文乱码问题

解决tomcat 启动 , 中文乱码问题. 第一步找到server.xml, 找到连接器, 添加 URIEncoding"UTF-8" 注意是英文的引号. 第二步, 找到 logging.properties , 在其中找到 第三步,启动服务, 观察现象,亲测有效.

社区内涝积水监测系统作用,改善社区积水

随着社区化进程的加速&#xff0c;社区基础设施的重要性日益凸显。在这个背景下&#xff0c;社区生命线和内涝积水监测系统成为了关注的焦点。它们在维护社区安全&#xff0c;特别是在应对暴雨等极端天气条件下&#xff0c;发挥着至关重要的作用。 WITBEE万宾时刻关注社区内涝积…

Cookie要怎么测试?

前言 Cookie是一种用于在web应用程序中存储用户特定信息的方法&#xff0c;可以让网站服务器把少量数据存储到客户端的硬盘或内存&#xff0c;或是从客户端的硬盘读取数据。Cookie的测试是指对Cookie的功能、性能、安全性、兼容性等方面进行验证的过程。 同时&#xff0c;在这…

arthas使用

官方文档 Github: https://github.com/alibaba/arthas 文档: https://arthas.aliyun.com/doc/ Arthas 是一款线上监控诊断产品&#xff0c;通过全局视角实时查看应用 load、内存、gc、线程的状态信息&#xff0c;并能在不修改应用代码的情况下&#xff0c;对业务问题进行诊断…

民安智库(第三方市场调查公司):专业调研引领某月饼生产商企业发展

在中国的传统佳节中&#xff0c;月饼是一种重要的节日食品&#xff0c;也是送礼的首选。某月饼生产商一直以来以其高品质、独特口味的月饼而备受消费者喜爱。为了更好地了解消费者对产品的满意度&#xff0c;该月饼生产商决定委托民安智库&#xff08;湖北知名满意度测评公司&a…

eBay需要添加什么卡可以付费?

前言 最近很多朋友不管是做eBay易贝卖家还是在eBay海淘下单购物的都或多或少遇到无法用卡付费的问题&#xff0c;甚至很多朋友之前明明可以用卡去付费&#xff0c;第一次可以第二次却不行了&#xff0c;想不到吧&#xff1f;这eBay平台还有这种骚操作&#xff0c;那到底用什么…

Spring-Mybatis源码解析--手写代码实现Spring整合Mybatis

文章目录 前言一、引入&#xff1a;二、准备工作&#xff1a;2.1 引入依赖2.2 数据源的文件&#xff1a;2.1 数据源&#xff1a; 2.3 业务文件&#xff1a; 三、整合的实现&#xff1a;3.1 xxxMapper 接口的扫描&#xff1a;3.2 xxxMapper 接口代理对象的生成&#xff1a;3.2 S…

vr智慧党建展厅超强参与感增强党员群众认同感、归属感

党建教育与VR虚拟现实技术的结合&#xff0c;是顺应现代信息化发展趋势的要求&#xff0c;不仅打破了传统党建教育的束缚&#xff0c;还丰富了党建宣传教育的渠道&#xff0c;党建教育VR云课堂平台是基于深圳华锐视点自主研发的VR云课堂平台中去体验各种VR党建教育软件或者视频…

水利遥测终端机RTU的重要作用

水利遥测终端机RTU是一种在水利行业中广泛应用的设备&#xff0c;它利用先进的传感技术和远程通信技术&#xff0c;实时监测和采集水利系统的各项指标数据。它的出现不仅提高了水利行业的运行效率&#xff0c;同时也使水利管理更加科学和精细化。 ■水利遥测终端机RTU具备高精度…

如何有效的进行 E2E

一、前言 本文作者介绍了什么是E2E测试以及E2E测试测什么&#xff0c;并从对于被测系统、测试用例、测试自动化工具、测试者四个方面的要求&#xff0c;介绍了如何保证E2E测试有效性&#xff0c;干货满满&#xff0c;值得学习。 二、什么是E2E测试 相信每一个对自动化测试感…

记录一次前后端传参方式不一致异常

✅作者简介&#xff1a;大家好&#xff0c;我是Leo&#xff0c;热爱Java后端开发者&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;Leo的博客 &#x1f49e;当前专栏&#xff1a; 报错以及Bug ✨特色专栏&#xff1a; …

同城按摩理疗APP小程序开发制作流程;

同城按摩理疗APP小程序开发制作流程&#xff1b; 开发同城按摩理疗APP小程序&#xff0c;首先需要进行市场调研&#xff0c;深入了解用户需求&#xff0c;明确小程序的定位和服务对象。接着&#xff0c;根据需求分析结果&#xff0c;制定详细的设计方案和开发计划。然后&#…

python中的enumerate函数

enumerate函数是Python内置builtins模块中的一个函数&#xff0c;用于将一个可迭代对象转换为一个索引-元素对的枚举对象&#xff0c;从而方便地同时获得索引和元素&#xff0c;并在循环迭代中使用。 enumerate函数的语法格式为&#xff1a;enumerate(iterable, start0) itera…

Sui与阿联酋科技孵化器Hub71合作支持生态项目建设,扩大全球影响力

近日&#xff0c;总部位于阿联酋&#xff08; United Arab Emirates &#xff0c;UAE&#xff09;的科技孵化器Hub71宣布与Mysten Labs合作&#xff0c;将支持Sui上的新项目。通过本次合作&#xff0c;孵化项目的开发者们不仅可以获得Mysten Labs的技术专业知识和支持&#xff…

基于Java SSM框架+Vue实现垃圾分类网站系统项目【项目源码+论文说明】计算机毕业设计

基于java的SSM框架Vue实现垃圾分类网站系统演示 摘要 本论文主要论述了如何使用JAVA语言开发一个垃圾分类网站 &#xff0c;本系统将严格按照软件开发流程进行各个阶段的工作&#xff0c;采用B/S架构&#xff0c;面向对象编程思想进行项目开发。在引言中&#xff0c;作者将论述…

北斗卫星助力乡村治理,走进数字化新时代

北斗卫星助力乡村治理&#xff0c;走进数字化新时代 随着国家对乡村治理越来越重视&#xff0c;为了进一步提升乡村治理水平&#xff0c;我国已经启动了全面建设现代化强国的大计划&#xff0c;其中数字化成为了重要的一环。而北斗卫星作为我国自主研制的卫星导航系统&#xff…

【带头学C++】----- 九、类和对象 ---- 9.1 类和对象的基本概念----(9.1.1---9.1.3)

目录 9.1 类和对象的基本概念 9.1.1 类的封装性 9.1.2 定义类的步骤和方法 9.1.3 设计一个学生类 Student 9.1 类和对象的基本概念 9.1.1 类的封装性 类是一种用户自定义的数据类型&#xff0c;它定义了一组数据成员和成员函数。类可以看作是一个模板或者蓝图&#xff0c;用…