昇思25天学习打卡营第12天|MindSpore 助力下的 GPT2:数据集加载处理及模型全攻略

环境配置


        %%capture captured_output 此乃 Jupyter Notebook 中的一个魔法命令,其作用在于捕获后续单元格中的输出,并将之存储于变量 captured_output 当中,而非直接于输出区域予以显示。如此一来,便可隐匿某些可能存在的输出信息。

        通过 pip 命令来卸载已安装的 mindspore 库(倘若其存在的话),其中 -y 这一参数意味着自动确认卸载操作。接着,安装 mindspore 库,且其版本指定为 2.2.14 。随后,安装与自然语言处理相关的 mindnlp 库,以及用于中文分词的 jieba 库。此外,还设置了环境变量 HF_ENDPOINT ,将其值设定为 https://hf-mirror.com 。在代码的执行进程中,环境变量能够被相关的程序或者库所运用。在此处,这一操作或许是为了明确某个与 HF 相关(HF 可能是某种特定的框架、服务或者资源)的端点或者镜像地址,从而使得后续的操作能够准确无误地访问或者使用相应的资源。

        代码如下:

%%capture captured_output  
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号  
!pip uninstall mindspore -y  
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14  
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`  
!pip install mindnlp  
!pip install jieba  
%env HF_ENDPOINT=https://hf-mirror.com 

        运行结果:

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple  
Collecting mindnlp  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/72/37/ef313c23fd587c3d1f46b0741c98235aecdfd93b4d6d446376f3db6a552c/mindnlp-0.3.1-py3-none-any.whl (5.7 MB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.7/5.7 MB 13.6 MB/s eta 0:00:00a 0:00:01  
Requirement already satisfied: mindspore in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.2.14)  
Requirement already satisfied: tqdm in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (4.66.4)  
Requirement already satisfied: requests in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.32.3)  
Collecting datasets (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/60/2d/963b266bb8f88492d5ab4232d74292af8beb5b6fdae97902df9e284d4c32/datasets-2.20.0-py3-none-any.whl (547 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 547.8/547.8 kB 5.2 MB/s eta 0:00:00a 0:00:01  
Collecting evaluate (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c2/d6/ff9baefc8fc679dcd9eb21b29da3ef10c81aa36be630a7ae78e4611588e1/evaluate-0.4.2-py3-none-any.whl (84 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.1/84.1 kB 11.5 MB/s eta 0:00:00  
Collecting tokenizers (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ba/26/139bd2371228a0e203da7b3e3eddcb02f45b2b7edd91df00e342e4b55e13/tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (3.6 MB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 13.0 MB/s eta 0:00:00a 0:00:01  
Collecting safetensors (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/02/28e6280ed0f1bde89eed644b80f2ece4e5ae212dc9ee70d7f56fadc93602/safetensors-0.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.2 MB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 15.6 MB/s eta 0:00:00a 0:00:01  
Collecting sentencepiece (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a3/69/e96ef68261fa5b82379fdedb325ceaf1d353c6e839ec346d8244e0da5f2f/sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.3 MB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 18.8 MB/s eta 0:00:00a 0:00:01  
Collecting regex (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/70/70/fea4865c89a841432497d1abbfd53878513b55c6543245fabe31cf8df0b8/regex-2024.5.15-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (774 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 774.7/774.7 kB 15.4 MB/s eta 0:00:00a 0:00:01  
Collecting addict (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/6a/00/b08f23b7d7e1e14ce01419a467b583edbb93c6cdb8654e54a9cc579cd61f/addict-2.4.0-py3-none-any.whl (3.8 kB)  
Collecting ml-dtypes (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/50/96/13d7c3cc82d5ef597279216cf56ff461f8b57e7096a3ef10246a83ca80c0/ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (2.2 MB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 11.5 MB/s eta 0:00:00a 0:00:01  
Collecting pyctcdecode (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a5/8a/93e2118411ae5e861d4f4ce65578c62e85d0f1d9cb389bd63bd57130604e/pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)  
Collecting jieba (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/cb/18eeb235f833b726522d7ebed54f2278ce28ba9438e3135ab0278d9792a2/jieba-0.42.1.tar.gz (19.2 MB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 19.2/19.2 MB 11.0 MB/s eta 0:00:0000:0100:01  
  Preparing metadata (setup.py) ... done  
Collecting pytest==7.2.0 (from mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/67/68/a5eb36c3a8540594b6035e6cdae40c1ef1b6a2bfacbecc3d1a544583c078/pytest-7.2.0-py3-none-any.whl (316 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 316.8/316.8 kB 12.8 MB/s eta 0:00:00  
Requirement already satisfied: attrs>=19.2.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (23.2.0)  
Requirement already satisfied: iniconfig in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.0.0)  
Requirement already satisfied: packaging in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (23.2)  
Requirement already satisfied: pluggy<2.0,>=0.12 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.5.0)  
Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.2.0)  
Requirement already satisfied: tomli>=1.0.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.0.1)  
Requirement already satisfied: filelock in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (3.15.3)  
Requirement already satisfied: numpy>=1.17 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (1.26.4)  
Collecting pyarrow>=15.0.0 (from datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/87/60/cc0645eb4ef73f88847e40a7f9d238bae6b7409d6c1f6a5d200d8ade1f09/pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl (38.1 MB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 38.1/38.1 MB 10.3 MB/s eta 0:00:0000:0100:01  
Collecting pyarrow-hotfix (from datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e4/f4/9ec2222f5f5f8ea04f66f184caafd991a39c8782e31f5b0266f101cb68ca/pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)  
Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.3.8)  
Requirement already satisfied: pandas in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (2.2.2)  
Collecting xxhash (from datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/7c/b9/93f860969093d5d1c4fa60c75ca351b212560de68f33dc0da04c89b7dc1b/xxhash-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (220 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 220.6/220.6 kB 22.5 MB/s eta 0:00:00  
Collecting multiprocess (from datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl (133 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 133.4/133.4 kB 15.2 MB/s eta 0:00:00  
Collecting fsspec<=2024.5.0,>=2023.1.0 (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ba/a3/16e9fe32187e9c8bc7f9b7bcd9728529faa725231a0c96f2f98714ff2fc5/fsspec-2024.5.0-py3-none-any.whl (316 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 316.1/316.1 kB 16.8 MB/s eta 0:00:00  
Collecting aiohttp (from datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/eb/45/eebe8d2215328434f33ccb44a05d2741ff7ed4b96b56ca507e2ecf598b73/aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.2 MB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 7.4 MB/s eta 0:00:00a 0:00:01  
Requirement already satisfied: huggingface-hub>=0.21.2 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.23.4)  
Requirement already satisfied: pyyaml>=5.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (6.0.1)  
Requirement already satisfied: charset-normalizer<4,>=2 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (3.3.2)  
Requirement already satisfied: idna<4,>=2.5 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (3.7)  
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2.2.2)  
Requirement already satisfied: certifi>=2017.4.17 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2024.6.2)  
Requirement already satisfied: protobuf>=3.13.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (5.27.1)  
Requirement already satisfied: asttokens>=2.0.4 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (2.0.5)  
Requirement already satisfied: pillow>=6.2.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (10.3.0)  
Requirement already satisfied: scipy>=1.5.4 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (1.13.1)  
Requirement already satisfied: psutil>=5.6.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (5.9.0)  
Requirement already satisfied: astunparse>=1.6.3 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (1.6.3)  
Collecting pygtrie<3.0,>=2.1 (from pyctcdecode->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ec/cd/bd196b2cf014afb1009de8b0f05ecd54011d881944e62763f3c1b1e8ef37/pygtrie-2.5.0-py3-none-any.whl (25 kB)  
Collecting hypothesis<7,>=6.14 (from pyctcdecode->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/79/21/83a2c2472c75e6443cd2068a51bf5b9d07505702adae5fb887fe0318629e/hypothesis-6.108.1-py3-none-any.whl (465 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 465.0/465.0 kB 14.0 MB/s eta 0:00:00  
Requirement already satisfied: six in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from asttokens>=2.0.4->mindspore->mindnlp) (1.16.0)  
Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore->mindnlp) (0.43.0)  
Collecting aiosignal>=1.1.2 (from aiohttp->datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl (7.6 kB)  
Collecting frozenlist>=1.1.1 (from aiohttp->datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/57/15/172af60c7e150a1d88ecc832f2590721166ae41eab582172fe1e9844eab4/frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (239 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 239.4/239.4 kB 14.4 MB/s eta 0:00:00  
Collecting multidict<7.0,>=4.5 (from aiohttp->datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d0/10/2ff646c471e84af25fe8111985ffb8ec85a3f6e1ade8643bfcfcc0f4d2b1/multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (125 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 125.9/125.9 kB 14.1 MB/s eta 0:00:00  
Collecting yarl<2.0,>=1.0 (from aiohttp->datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/d6/5b30ae1d8a13104ee2ceb649f28f2db5ad42afbd5697fd0fc61528bb112c/yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (300 kB)  
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 300.9/300.9 kB 14.1 MB/s eta 0:00:00  
Collecting async-timeout<5.0,>=4.0 (from aiohttp->datasets->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl (5.7 kB)  
Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from huggingface-hub>=0.21.2->datasets->mindnlp) (4.11.0)  
Collecting sortedcontainers<3.0.0,>=2.1.0 (from hypothesis<7,>=6.14->pyctcdecode->mindnlp)  
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)  
Requirement already satisfied: python-dateutil>=2.8.2 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2.9.0.post0)  
Requirement already satisfied: pytz>=2020.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2024.1)  
Requirement already satisfied: tzdata>=2022.7 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2024.1)  
Building wheels for collected packages: jieba  
  Building wheel for jieba (setup.py) ... done  
  Created wheel for jieba: filename=jieba-0.42.1-py3-none-any.whl size=19314459 sha256=c13a2538f852cf7ab019a1996950690b65a203218e95481705edd590f7135a5e  
  Stored in directory: /home/nginx/.cache/pip/wheels/1a/76/68/b6d79c4db704bb18d54f6a73ab551185f4711f9730c0c15d97  
Successfully built jieba  
Installing collected packages: sortedcontainers, sentencepiece, pygtrie, jieba, addict, xxhash, safetensors, regex, pytest, pyarrow-hotfix, pyarrow, multiprocess, multidict, ml-dtypes, hypothesis, fsspec, frozenlist, async-timeout, yarl, pyctcdecode, aiosignal, tokenizers, aiohttp, datasets, evaluate, mindnlp  
  Attempting uninstall: pytest  
    Found existing installation: pytest 8.0.0  
    Uninstalling pytest-8.0.0:  
      Successfully uninstalled pytest-8.0.0  
  Attempting uninstall: fsspec  
    Found existing installation: fsspec 2024.6.0  
    Uninstalling fsspec-2024.6.0:  
      Successfully uninstalled fsspec-2024.6.0  
Successfully installed addict-2.4.0 aiohttp-3.9.5 aiosignal-1.3.1 async-timeout-4.0.3 datasets-2.20.0 evaluate-0.4.2 frozenlist-1.4.1 fsspec-2024.5.0 hypothesis-6.108.1 jieba-0.42.1 mindnlp-0.3.1 ml-dtypes-0.4.0 multidict-6.0.5 multiprocess-0.70.16 pyarrow-16.1.0 pyarrow-hotfix-0.6 pyctcdecode-0.5.0 pygtrie-2.5.0 pytest-7.2.0 regex-2024.5.15 safetensors-0.4.3 sentencepiece-0.2.0 sortedcontainers-2.4.0 tokenizers-0.19.1 xxhash-3.4.1 yarl-1.9.4  
  
[notice] A new release of pip is available: 24.1 -> 24.1.2  
[notice] To update, run: python -m pip install --upgrade pip  
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple  
Requirement already satisfied: jieba in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (0.42.1)  
  
[notice] A new release of pip is available: 24.1 -> 24.1.2  
[notice] To update, run: python -m pip install --upgrade pip  
env: HF_ENDPOINT=https://hf-mirror.com  

数据准备


        在数据准备阶段,所进行的一系列操作皆是为后续的模型训练与评估开展前期的数据处理相关工作。这一阶段的主要任务是对数据进行有效的整理、清洗、转换和划分,以便为模型的训练提供高质量、结构清晰且易于处理的数据,从而为后续模型能够准确且高效地进行训练和评估奠定坚实基础。

        代码如下:

import os  
import mindspore  
from mindspore.dataset import text, GeneratorDataset, transforms  
from mindspore import nn  
from mindnlp.dataset import load_dataset  
from mindnlp._legacy.engine import Trainer, Evaluator  
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback  
from mindnlp._legacy.metrics import Accuracy  
imdb_ds = load_dataset('imdb', split=['train', 'test'])  
imdb_train = imdb_ds['train']  
imdb_test = imdb_ds['test']  
imdb_train.get_dataset_size()  

        分析:其主要运用了若干库和模块来开展处理数据以及与模型训练相关的操作。

        首先,导入了一系列必需的库,例如 os、mindspore 以及和数据处理及模型有关的模块。

        随后,借助 load_dataset 函数载入了 imdb 数据集,并将其划分为 train 和 test 这两个部分。

        紧接着,利用 get_dataset_size 方法获取到了 imdb_train 数据集的规模大小。

        运行结果:

处理数据集


        定义了一个名为“process_dataset”的函数,其旨在针对给定的数据集实施预处理流程。该函数的核心要义在于对输入的数据集展开分词、类型转换以及批量处理等一系列预处理操作,从而为后续将其应用于模型训练或其他相关任务做好充分准备。

        代码如下:

import numpy as np  
def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):  
    is_ascend = mindspore.get_context('device_target') == 'Ascend'  
    def tokenize(text):  
        if is_ascend:  
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)  
        else:  
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)  
        return tokenized['input_ids'], tokenized['attention_mask']  
  
    if shuffle:  
        dataset = dataset.shuffle(batch_size)  
  
    # map dataset  
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])  
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")  
    # batch dataset  
    if is_ascend:  
        dataset = dataset.batch(batch_size)  
    else:  
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),  
                                                             'attention_mask': (None, 0)})  
    return dataset  

        分析:首先,引入了 numpy 库。

        函数接受数据集 dataset、分词器 tokenizer 以及一些可选参数,如最大序列长度 max_seq_len、批量大小 batch_size 和是否打乱数据 shuffle 。

        通过判断运行环境是否为 Ascend 来决定在分词时是否进行填充操作。

        如果 shuffle 为 True,则对数据集进行打乱,打乱的批量大小为 batch_size 。

        然后,对数据集进行映射操作,将文本列 text 通过 tokenize 函数处理,得到 input_ids 和 attention_mask 列。接着将标签列 label 转换为指定的数据类型。

        最后,根据运行环境是否为 Ascend ,以不同的方式对数据集进行批量处理,包括填充策略。

GPTTokenizer


        进行与自然语言处理中的分词器(GPTTokenizer)相关的操作。

        代码如下:

from mindnlp.transformers import GPTTokenizer  
# tokenizer  
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')  
# add sepcial token: <PAD>  
special_tokens_dict = {  
    "bos_token": "<bos>",  
    "eos_token": "<eos>",  
    "pad_token": "<pad>",  
}  
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)  

        分析:首先,从 mindnlp.transformers 模块导入了 GPTTokenizer 类。

        然后,通过 from_pretrained 方法从预训练的模型 'openai-gpt' 初始化一个 GPTTokenizer 对象 gpt_tokenizer 。

        接着,定义了一个名为 special_tokens_dict 的字典,其中包含了特殊标记 bos_token(起始标记)、eos_token(结束标记)和 pad_token(填充标记),并为它们指定了相应的值。

        最后,使用 add_special_tokens 方法将定义的特殊标记添加到 gpt_tokenizer 中,并将添加的特殊标记数量存储在 num_added_toks 变量中。

数据集分割和处理


        代码如下:

# 将名为 `imdb_train` 的数据集按照比例 0.7 和 0.3 分割为两个数据集,分别命名为 `imdb_train` 和 `imdb_val`  
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])  
# 使用 `process_dataset` 函数对分割出来的 `imdb_train` 数据集进行处理,处理时启用了随机打乱(`shuffle=True`),并将处理结果存储在 `dataset_train` 中  
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)  
# 使用 `process_dataset` 函数对分割出来的 `imdb_val` 数据集进行处理,未启用随机打乱,并将处理结果存储在 `dataset_val` 中  
dataset_val = process_dataset(imdb_val, gpt_tokenizer)  
# 使用 `process_dataset` 函数对名为 `imdb_test` 的数据集进行处理,并将处理结果存储在 `dataset_test` 中  
dataset_test = process_dataset(imdb_test, gpt_tokenizer)  
# 从 `dataset_train` 创建的元组迭代器中获取下一个元素  
next(dataset_train.create_tuple_iterator())  

        分析:如果 imdb_train 原本包含 100 个样本,按照 0.7 的比例分割后,imdb_train 可能包含 70 个样本,imdb_val 包含 30 个样本。然后经过 process_dataset 函数的处理,得到可用于后续操作的 dataset_train 、dataset_val 和 dataset_test 。最后,通过 next(dataset_train.create_tuple_iterator()) 尝试获取 dataset_train 迭代器中的第一个元素。

训练和评估


        在实际应用中,对基于 GPT 的序列分类任务进行微调训练和评估。训练过程中会按照设置的参数进行,如学习率、轮数等,并通过回调函数保存训练过程中的检查点和最佳模型。评估过程则使用测试数据集来评估训练好的模型的性能。

        代码如下

from mindnlp.transformers import GPTForSequenceClassification  
from mindspore.experimental.optim import Adam  
# set bert config and define parameters for training  
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)  
model.config.pad_token_id = gpt_tokenizer.pad_token_id  
model.resize_token_embeddings(model.config.vocab_size + 3)  
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)  
metric = Accuracy()  
# define callbacks to save checkpoints  
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)  
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)  
trainer = Trainer(network=model, train_dataset=dataset_train,  
                  eval_dataset=dataset_train, metrics=metric,  
                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],  
                  jit=False)  
trainer.run(tgt_columns="labels")  
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)  
evaluator.run(tgt_columns="labels")  

         分析:首先从'mindnlp.transformers'模块导入'GPTForSequenceClassification'类,从'mindspore.experimental.optim'模块导入'Adam'优化器,

        然后从预训练的模型'openai-gpt'加载'GPTForSequenceClassification'模型,并设置分类的类别数为 2。将模型配置中的填充标记 ID 设置为之前定义的分词器的填充标记 ID。创建训练器对象,指定模型、训练数据集、评估数据集、评估指标、训练轮数、优化器、回调函数以及是否使用即时编译(jit=False),运行训练器,目标列是"labels",创建评估器对象,指定模型、评估数据集和评估指标,

        最后运行评估器,目标列是"labels"。

运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/799024.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(九)-无人机服务区分离

引言 本文是3GPP TR 22.829 V17.1.0技术报告&#xff0c;专注于无人机&#xff08;UAV&#xff09;在3GPP系统中的增强支持。文章提出了多个无人机应用场景&#xff0c;分析了相应的能力要求&#xff0c;并建议了新的服务级别要求和关键性能指标&#xff08;KPIs&#xff09;。…

小程序里面使用vant ui中的vant-field组件,如何使得输入框自动获取焦点

//.wxml <van-fieldmodel:value"{{ userName }}"placeholder"请输入学号"focus"{{focusUserName}}"/>// .js this.setData({focusUserName: true});vant-field

钡铼ARMxy控制器在智能网关中的应用

随着IoT物联网技术的飞速发展&#xff0c;智能网关作为连接感知层与网络层的枢纽&#xff0c;可以实现感知网络和通信网络以及不同类型感知网络之间的协议转换。钡铼技术的ARMxy系列控制器凭借其高性能、低功耗和高度灵活性的特点&#xff0c;在智能网关中发挥了关键作用&#…

RPC与服务的注册发现

文章目录 1. 什么是远程过程调用(RPC)?2. RPC的流程3. RPC实践4. RPC与REST的区别4.1 RPC与REST的相似之处4.2 RPC与REST的架构原则4.3 RPC与REST的主要区别 5. RPC与服务发现5.1 以zookeeper为服务注册中心5.2 以etcd为服务注册中心 6. 小结参考 1. 什么是远程过程调用(RPC)?…

大语言模型诞生过程剖析

过程图如下 &#x1f4da; 第一步&#xff1a;海量文本的无监督学习 得到基座大模型&#x1f389; &#x1f50d; 原料&#xff1a;首先&#xff0c;我们需要海量的文本数据&#xff0c;这些数据可以来自互联网上的各种语料库&#xff0c;包括书籍、新闻、科学论文、社交媒体帖…

<数据集>光伏板缺陷检测数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;2400张 标注数量(xml文件个数)&#xff1a;2400 标注数量(txt文件个数)&#xff1a;2400 标注类别数&#xff1a;4 标注类别名称&#xff1a;[Crack,Grid,Spot] 序号类别名称图片数框数1Crack8688922Grid8248843S…

全栈智能家居系统设计方案:STM32+Linux+多协议(MQTT、Zigbee、Z-Wave)通信+云平台集成

1. 项目概述 随着物联网技术的快速发展,智能家居系统正在成为现代生活中不可或缺的一部分。本文介绍了一个基于STM32微控制器和Linux系统的智能家居解决方案,涵盖了硬件设计、软件架构、通信协议以及云平台集成等方面。 该系统具有以下特点: 采用STM32作为终端设备的控制核心…

springboot3——项目部署

springboot的项目开发完了&#xff0c;怎么样把他放到服务器上或者生产环境上让他运行起来跑起来。就要牵扯到项目部署&#xff0c;打包的方式了。 springboot支持jar和war: 打jar包&#xff1a;默认方式&#xff0c;项目开发完打个jar包&#xff0c;通过命令把jar包起起来就…

汇川ST 实现分拣

//初始化 IF init FALSE THEN// 初始化init : 1 ;//45 Y数组 BOOL[8] [OFF发料Y OFF分拣Y OFF送料Y OFF取料Y OFF摆取Y OFF摆放Y OFF升降Y OFF夹料Y] [OFF发料Y OFF分拣Y OFF送料Y OFF取料Y OFF摆取Y OFF摆放Y OFF升降Y OFF夹料Y] 不保持 私有 Y0(*Y数组[0] BOOL OFF 发料…

MySQL 中的几种锁

MySQL 中的锁 #按锁粒度如何划分? 按锁粒度划分的话&#xff0c;MySQL 的锁有&#xff1a; 表锁&#xff1a;开销小&#xff0c;加锁快&#xff1b;锁定力度大&#xff0c;发生锁冲突概率高&#xff0c;并发度最低;不会出现死锁。行锁&#xff1a;开销大&#xff0c;加锁慢…

unity宏编译版本

在写c程序的时候我们通常可以用不同的宏定义来控制不同版本的编译内容&#xff0c;最近有个需求就是根据需要编译一个完全体验版本&#xff0c;就想到了用vs的那套方法。经过研究发现unity也有类似的控制方法。 注意这里设置完后要点击右下的应用&#xff0c;我起先就没有设置…

7/13 - 7/15

vo.setId(rs.getLong("id"))什么意思&#xff1f; vo.setId(rs.getLong("id")); 这行代码是在Java中使用ResultSet对象&#xff08;通常用于从数据库中检索数据&#xff09;获取一个名为"id"的列&#xff0c;并将其作为long类型设置为一个对象…

深度学习基础:Numpy 数组包

数组基础 在使用导入 Numpy 时&#xff0c;通常给其一个别名 “np”&#xff0c;即 import numpy as np 。 数据类型 整数类型数组与浮点类型数组 为了克服列表的缺点&#xff0c;一个 Numpy 数组只容纳一种数据类型&#xff0c;以节约内存。为方便起见&#xff0c;可将 Nu…

简洁实用的原创度检测工具AntiPlagiarism NET 4.132

AntiPlagiarism NET是一个适用于Windows的程序&#xff0c;它允许您检查文本的唯一性和从不同Internet来源借用的存在。使用AntiPlagiarism NET&#xff0c;您可以&#xff1a; 将程序用于不同的目的该程序适用于学生、教师、记者、文案作者和其他需要检查其文本或其他作者文本…

FPGA 项目菜单功能比较

为了帮助您更好地理解每个FPGA功能模块的实用场合、区别和特点&#xff0c;以下是详细的比较&#xff1a; 功能模块实用场合区别特点FPGA I/O自动控制系统、数据采集系统直接与FPGA板卡上的物理端口交互&#xff0c;配置和使用外部I/O设备灵活配置输入输出端口&#xff0c;支持…

开放式蓝牙耳机哪家强?五款值得入手的款式推荐!

当前&#xff0c;耳机技术迎来颠覆性飞跃&#xff0c;开放式蓝牙耳机横空出世&#xff0c;以其卓越的音质体验与无可比拟的听力保护优势&#xff0c;重塑了耳机行业的标准。其非入耳设计&#xff0c;既保证了音质的纯粹传达&#xff0c;又有效避免了长时间佩戴对耳朵的潜在伤害…

板级调试小助手(2)ZYNQ自定义IP核构建属于自己的DDS外设

一、前言 在上期文章中讲述了小助手的系统结构和原理。在PYNQ的框架开发中&#xff0c;我们一般可以将PL端当做PS端的一个外设&#xff0c;通过读写寄存器的方式来操作外设的功能&#xff0c;就类似于在开发ARM和DSP中操作外设一样&#xff0c;不同时的是&#xff0c;我们可以通…

AIGC率超标?掌握论文去AI痕迹的高效策略

随着 AI 技术迅猛发展&#xff0c;各种AI辅助论文写作的工具层出不穷&#xff01; 为了防止有人利用AI工具进行论文代写&#xff0c;在最新的学位法中已经明确规定“已经获得学位者&#xff0c;在获得该学位过程中如有人工智能代写等学术不端行为&#xff0c;经学位评定委员会…

智慧公厕系统实现人性化与节能化的完美结合

在当今社会&#xff0c;科技的飞速发展正不断改变着我们的生活方式&#xff0c;公厕也不例外。智慧公厕系统的出现&#xff0c;不仅提升了人们的使用体验&#xff0c;更实现了人性化与节能化的完美结合&#xff0c;为城市公共服务带来了全新的变革。 一、人性化&#xff0c;是智…

清朝嘉庆二十五年(1820年)地图数据

我们在《中国历史行政区划连续变化数据》一文中&#xff0c;为你分享了中国历史行政区划连续变化地图数据。 现在再为你分享清朝嘉庆二十五年&#xff08;1820年&#xff09;的地图数据&#xff0c;该数据对于研究历史的朋友应该比较有用&#xff0c;请在文末查看领取方式。 …