从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

从零开发短视频电商 在AWS上SageMaker部署模型自定义日志输入和输出示例

怎么部署自定义模型请看:从零开发短视频电商 在AWS上用SageMaker部署自定义模型

  • 都是huaggingface上的模型或者fine-tune后的。

为了适配jumpstart上部署的模型的http输入输出,我在自定义模型中自定义了适配的输入输出,可以做到兼容适配

code/inference.py

  • 容器的原始代码入口:https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/handler_service.py
  • 默认支持的decode和encode:https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
  • 可以用这个在sagemaker上使用jupyterlab:https://github.com/huggingface/notebooks/blob/main/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb

我们自定义的逻辑如下

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import json
import logging
// --------- 这块
logger = logging.getLogger()
logger.setLevel(logging.INFO)
// 自定义http输入,可以适配不同的content_type ,打印输入的日志
// 源码参见下面的 preprocess
def input_fn(input_data, content_type):
    logger.info(f"laker input_data {input_data} and content_type {content_type}")
    if content_type == "application/json":
        request = json.loads(input_data)
    elif content_type == "application/x-text":
        request = {"inputs": input_data.decode('utf-8')}
    else:
        request = {"inputs": input_data} 
    logger.info(f"laker input_fn request {request} ")
    return request
// 自定义输出
def output_fn(prediction, accept):
    return encode_json(prediction)  

// 来自https://github.com/aws/sagemaker-huggingface-inference-toolkit/blob/80634b30703e8e9525db8b7128b05f713f42f9dc/src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py#L102C1-L113C6
    
class _JSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif hasattr(obj, "tolist"):
            return obj.tolist()
        elif isinstance(obj, datetime.datetime):
            return obj.__str__()
        elif isinstance(obj, Image.Image):
            with BytesIO() as out:
                obj.save(out, format="PNG")
                png_string = out.getvalue()
                return base64.b64encode(png_string).decode("utf-8")
        else:
            return super(_JSONEncoder, self).default(obj)
        
def encode_json(content):
    """
    encodes json with custom `JSONEncoder`
    """
    return json.dumps(
        content,
        ensure_ascii=False,
        allow_nan=False,
        indent=None,
        cls=_JSONEncoder,
        separators=(",", ":"),
    )
// --------- 这块  end ---

# Helper: Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def model_fn(model_dir):
  # Load model from HuggingFace Hub
  tokenizer = AutoTokenizer.from_pretrained(model_dir)
  model = AutoModel.from_pretrained(model_dir)
  return model, tokenizer

def predict_fn(data, model_and_tokenizer):
    # destruct model and tokenizer
    model, tokenizer = model_and_tokenizer
    
    # Tokenize sentences
    sentences = data.pop("inputs", data)
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    
    # return dictonary, which will be json serializable
    return {"embedding": sentence_embeddings[0].tolist()}
import logging
from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder

logger = logging.getLogger(__name__)

 def preprocess(self, input_data, content_type, context=None):
        """
        The preprocess handler is responsible for deserializing the input data into
        an object for prediction, can handle JSON.
        The preprocess handler can be overridden for data or feature transformation.

        Args:
            input_data: the request payload serialized in the content_type format.
            content_type: the request content_type.
            context (obj): metadata on the incoming request data (default: None).

        Returns:
            decoded_input_data (dict): deserialized input_data into a Python dictonary.
        """
        # raises en error when using zero-shot-classification or table-question-answering, not possible due to nested properties
        if (
            os.environ.get("HF_TASK", None) == "zero-shot-classification"
            or os.environ.get("HF_TASK", None) == "table-question-answering"
        ) and content_type == content_types.CSV:
            raise PredictionException(
                f"content type {content_type} not support with {os.environ.get('HF_TASK', 'unknown task')}, use different content_type",
                400,
            )

        decoded_input_data = decoder_encoder.decode(input_data, content_type)
        return decoded_input_data
    
    
    

   logger.info(
        f"param1 {batch_size} and param2 {sequence_length}"
    )



def predict(self, data, model, context=None):
        """The predict handler is responsible for model predictions. Calls the `__call__` method of the provided `Pipeline`
        on decoded_input_data deserialized in input_fn. Runs prediction on GPU if is available.
        The predict handler can be overridden to implement the model inference.

        Args:
            data (dict): deserialized decoded_input_data returned by the input_fn
            model : Model returned by the `load` method or if it is a custom module `model_fn`.
            context (obj): metadata on the incoming request data (default: None).

        Returns:
            obj (dict): prediction result.
        """

        # pop inputs for pipeline
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # pass inputs with all kwargs in data
        if parameters is not None:
            prediction = model(inputs, **parameters)
        else:
            prediction = model(inputs)
        return prediction

    def postprocess(self, prediction, accept, context=None):
        """
        The postprocess handler is responsible for serializing the prediction result to
        the desired accept type, can handle JSON.
        The postprocess handler can be overridden for inference response transformation.

        Args:
            prediction (dict): a prediction result from predict.
            accept (str): type which the output data needs to be serialized.
            context (obj): metadata on the incoming request data (default: None).
        Returns: output data serialized
        """
        return decoder_encoder.encode(prediction, accept)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/264618.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springMVC-与spring整合

一、基本介绍 在项目开发中,spring管理的 Service和 Respository,SrpingMVC管理 Controller和ControllerAdvice,分工明确 当我们同时配置application.xml, springDispatcherServlet-servlet.xml , 那么注解的对象会被创建两次, 故…

2023 下半年系统架构设计师学习进度

文章目录 复习计划:每周350分钟第一周(339分钟)第二周(265分钟)第三周(171分钟)第四周(214分钟)第五周(274分钟)第六周(191分钟&#…

初识Stable Diffusion

界面选项解读 这是在趋动云上部署的Stable Diffusion txt2img prompt (1)分割符号:使用逗号 , 用于分割词缀,且有一定权重排序功能,逗号前权重高,逗号后权重低 (2)建议的通用范式…

【Java JMM】编译和优化

1 前端编译 在 Java 技术下, “编译期” 是一个比较含糊的表述, 因为它可能指的是 前端编译器 (“编译器的前端” 更准确一些) 把 *.java 文件转变成 *.class 文件的过程Java 虚拟机的即时编译器 (常称 JIT 编译器, Just In Time Compiler) 运行期把字节码转变成本地机器码的过…

《Python》面试常问:深拷贝、浅拷贝、赋值之间的关系(附可变与不可变)【用图文讲清楚!】

背景 想必大家面试或者平时学习经常遇到问python的深拷贝、浅拷贝和赋值之间的区别了吧?看网上的文章很多写的比较抽象,小白接收的难度有点大,于是乎也想自己整个文章出来供参考 可变与不可变 讲深拷贝和浅拷贝之前想讲讲什么是可变数据类型…

Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析

Pytorch常用的函数(五)np.meshgrid()和torch.meshgrid()函数解析 我们知道torch.meshgrid()函数的功能是生成网格,可以用于生成坐标; 在numpy中也有一样的函数np.meshgrid(),但是用法不太一样,我们直接上代码进行解释。 1、两者…

如何在Window系统下搭建Nginx服务器环境并部署前端项目

1.下载并安装Nginx 在nginx官网nginx: download 下载稳定版本至自己想要的目录。 解压后进入目录 2.启动Nginx服务器 启动方式有两种: (1)直接进入nginx安装目录下,双击nginx.exe运行,此时命令行窗口一闪而过&…

浏览器 cookie 的原理(详)

目录 1,cookie 的出现2,cookie 的组成浏览器自动发送 cookie 的条件 3,设置 cookie3.1,服务端设置3.1,客户端设置3.3,删除 cookie 4,使用流程总结 整理和测试花了很大时间,如果对你有…

python调用GPT API

每次让gpt给我生成一个调用api的程序时,他经常会调用以前的一些api的方法,导致我的程序运行错误,所以这期记录一下使用新的方法区调用api 参考网址 Migration Guide,这里简要地概括了一下新版本做了哪些更改 OpenAI Python API l…

引领汽车营销新趋势,3DCAT实时云渲染助力汽车三维可视化

当前,汽车产业发展正从电动化的上半场,向智能化的下半场迈进。除了车机技术体验的智能化之外,观车体验的智能化也不容忽视。 这是因为,随着数字化、智能化、个性化的趋势,消费者对汽车的需求和期待也越来越高&#xf…

2016年第五届数学建模国际赛小美赛B题直达地铁线路解题全过程文档及程序

2016年第五届数学建模国际赛小美赛 B题 直达地铁线路 原题再现: 在目前的大都市地铁网络中,在两个相距遥远的车站之间运送乘客通常需要很长时间。我们可以建议在两个长途车站之间设置直达班车,以节省长途乘客的时间。   第一部分&#xf…

Qt的简单游戏实现提供完整代码

文章目录 1 项目简介2 项目基本配置2.1 创建项目2.2 添加资源 3 主场景3.1 设置游戏主场景配置3.2 设置背景图片3.3 创建开始按钮3.4 开始按钮跳跃特效实现3.5 创建选择关卡场景3.6 点击开始按钮进入选择关卡场景 4 选择关卡场景4.1场景基本设置4.2 背景设置4.3 创建返回按钮4.…

Java面向对象(初级)

面向对象编程(基础) 面向对象编程(OOP)是一种编程范式,它强调程序设计是围绕对象、类和方法构建的。在面向对象编程中,程序被组织为一组对象,这些对象可以互相传递消息。面向对象编程的核心概念包括封装、继承和多态。…

2023.12.21 关于 Redis 常用数据结构 和 单线程模型

目录 各数据结构具体编码方式 查看 key 对应 value 的编码方式 Reids 单线程模型 经典面试题 IO 多路复用 Redis 常用数据结构 Redis 中所有的 key 均为 String 类型,而不同的是 value 的数据类型却有很多种以下介绍 5 种 value 常见的数据类型 注意&#xff1…

阿里云 ACK One 新特性:多集群网关,帮您快速构建同城容灾系统

云布道师 近日,阿里云分布式云容器平台 ACK One[1]发布“多集群网关”[2](ACK One Multi-cluster Gateways)新特性,这是 ACK One 面向多云、多集群场景提供的云原生网关,用于对多集群南北向流量进行统一管理。 基于 …

虚拟机的下载、安装(模拟出服务器)

下载 vmware workstation(收费的虚拟机) 下载vbox 网址:Oracle VM VirtualBox(免费的虚拟机) 以下选择一个下载即可,建议下载vbox,因为是免费的。安装的时候默认下一步即可(路径最好…

hiveserver负载均衡配置

一.安装nginx 参数我的另一篇文章:https://mp.csdn.net/mp_blog/creation/editor/135152478 二.配置nginx服务参数 worker_processes 1; events { worker_connections 1024; } stream { upstream hiveserver2 { # least_conn; # 使用最少连接路由…

八大排序算法@直接插入排序(C语言版本)

目录 直接插入排序概念算法思想代码实现核心算法:直接插入排序的算法实现: 特性总结 直接插入排序 概念 算法思想 把待排序的记录按其关键码值的大小逐个插入到一个已经排好序的有序序列中,直到所有的记录插入完为止,得到一个新…

【Spring实战】配置多数据源

文章目录 1. 配置数据源信息2. 创建第一个数据源3. 创建第二个数据源4. 创建启动类及查询方法5. 启动服务6. 创建表及做数据7. 查询验证8. 详细代码总结 通过上一节的介绍,我们已经知道了如何使用 Spring 进行数据源的配置以及应用。在一些复杂的应用中,…

文档 - - - Docsify文档创建

目录 1. Docsify 介绍2. 创建 Docsify 项目2.1 安装 Node.js2.1 安装 docsfiy-cli2.3 初始化项目2.4 运行项目2.5 使用 Python 运行项目(扩展,不推荐有bug) 3. 配置 Docsify 项目3.1 修改等待加载文字3.2 添加网站 ico 图标3.3 创建新页面写文…