Environment Setup
`%%capture captured_output` is a Jupyter Notebook magic command: it captures the output of the rest of the cell and stores it in the variable captured_output instead of printing it to the output area, which keeps the noisy installation logs out of the notebook.
The cell uninstalls any existing mindspore package via pip (the -y flag auto-confirms the uninstall), then installs mindspore pinned to version 2.2.14, followed by mindnlp (a natural-language-processing library built on MindSpore) and jieba (a Chinese word-segmentation library). Finally, it sets the environment variable HF_ENDPOINT to https://hf-mirror.com, so that any later access to the Hugging Face (HF) Hub, for example model and dataset downloads, goes through that mirror.
Code:
%%capture captured_output
# The environment comes with mindspore==2.2.14 preinstalled; to switch MindSpore versions, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# This example was verified against mindnlp 0.3.1; if it fails to run, pin the version with `!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com
Output:
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting mindnlp
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/72/37/ef313c23fd587c3d1f46b0741c98235aecdfd93b4d6d446376f3db6a552c/mindnlp-0.3.1-py3-none-any.whl (5.7 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.7/5.7 MB 13.6 MB/s eta 0:00:00
Requirement already satisfied: mindspore in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.2.14)
Requirement already satisfied: tqdm in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (4.66.4)
Requirement already satisfied: requests in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindnlp) (2.32.3)
Collecting datasets (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/60/2d/963b266bb8f88492d5ab4232d74292af8beb5b6fdae97902df9e284d4c32/datasets-2.20.0-py3-none-any.whl (547 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 547.8/547.8 kB 5.2 MB/s eta 0:00:00
Collecting evaluate (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c2/d6/ff9baefc8fc679dcd9eb21b29da3ef10c81aa36be630a7ae78e4611588e1/evaluate-0.4.2-py3-none-any.whl (84 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 84.1/84.1 kB 11.5 MB/s eta 0:00:00
Collecting tokenizers (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ba/26/139bd2371228a0e203da7b3e3eddcb02f45b2b7edd91df00e342e4b55e13/tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (3.6 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.6/3.6 MB 13.0 MB/s eta 0:00:00
Collecting safetensors (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/02/28e6280ed0f1bde89eed644b80f2ece4e5ae212dc9ee70d7f56fadc93602/safetensors-0.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 15.6 MB/s eta 0:00:00
Collecting sentencepiece (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a3/69/e96ef68261fa5b82379fdedb325ceaf1d353c6e839ec346d8244e0da5f2f/sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.3 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 18.8 MB/s eta 0:00:00
Collecting regex (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/70/70/fea4865c89a841432497d1abbfd53878513b55c6543245fabe31cf8df0b8/regex-2024.5.15-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (774 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 774.7/774.7 kB 15.4 MB/s eta 0:00:00
Collecting addict (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/6a/00/b08f23b7d7e1e14ce01419a467b583edbb93c6cdb8654e54a9cc579cd61f/addict-2.4.0-py3-none-any.whl (3.8 kB)
Collecting ml-dtypes (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/50/96/13d7c3cc82d5ef597279216cf56ff461f8b57e7096a3ef10246a83ca80c0/ml_dtypes-0.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (2.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 11.5 MB/s eta 0:00:00
Collecting pyctcdecode (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a5/8a/93e2118411ae5e861d4f4ce65578c62e85d0f1d9cb389bd63bd57130604e/pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)
Collecting jieba (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/cb/18eeb235f833b726522d7ebed54f2278ce28ba9438e3135ab0278d9792a2/jieba-0.42.1.tar.gz (19.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 19.2/19.2 MB 11.0 MB/s eta 0:00:00
Preparing metadata (setup.py) ... done
Collecting pytest==7.2.0 (from mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/67/68/a5eb36c3a8540594b6035e6cdae40c1ef1b6a2bfacbecc3d1a544583c078/pytest-7.2.0-py3-none-any.whl (316 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 316.8/316.8 kB 12.8 MB/s eta 0:00:00
Requirement already satisfied: attrs>=19.2.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (23.2.0)
Requirement already satisfied: iniconfig in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.0.0)
Requirement already satisfied: packaging in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (23.2)
Requirement already satisfied: pluggy<2.0,>=0.12 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.5.0)
Requirement already satisfied: exceptiongroup>=1.0.0rc8 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (1.2.0)
Requirement already satisfied: tomli>=1.0.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pytest==7.2.0->mindnlp) (2.0.1)
Requirement already satisfied: filelock in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (3.15.3)
Requirement already satisfied: numpy>=1.17 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (1.26.4)
Collecting pyarrow>=15.0.0 (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/87/60/cc0645eb4ef73f88847e40a7f9d238bae6b7409d6c1f6a5d200d8ade1f09/pyarrow-16.1.0-cp39-cp39-manylinux_2_28_aarch64.whl (38.1 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 38.1/38.1 MB 10.3 MB/s eta 0:00:00
Collecting pyarrow-hotfix (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/e4/f4/9ec2222f5f5f8ea04f66f184caafd991a39c8782e31f5b0266f101cb68ca/pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.3.8)
Requirement already satisfied: pandas in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (2.2.2)
Collecting xxhash (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/7c/b9/93f860969093d5d1c4fa60c75ca351b212560de68f33dc0da04c89b7dc1b/xxhash-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (220 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 220.6/220.6 kB 22.5 MB/s eta 0:00:00
Collecting multiprocess (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl (133 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 133.4/133.4 kB 15.2 MB/s eta 0:00:00
Collecting fsspec<=2024.5.0,>=2023.1.0 (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ba/a3/16e9fe32187e9c8bc7f9b7bcd9728529faa725231a0c96f2f98714ff2fc5/fsspec-2024.5.0-py3-none-any.whl (316 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 316.1/316.1 kB 16.8 MB/s eta 0:00:00
Collecting aiohttp (from datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/eb/45/eebe8d2215328434f33ccb44a05d2741ff7ed4b96b56ca507e2ecf598b73/aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (1.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 7.4 MB/s eta 0:00:00
Requirement already satisfied: huggingface-hub>=0.21.2 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (0.23.4)
Requirement already satisfied: pyyaml>=5.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from datasets->mindnlp) (6.0.1)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (3.3.2)
Requirement already satisfied: idna<4,>=2.5 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (3.7)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2.2.2)
Requirement already satisfied: certifi>=2017.4.17 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from requests->mindnlp) (2024.6.2)
Requirement already satisfied: protobuf>=3.13.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (5.27.1)
Requirement already satisfied: asttokens>=2.0.4 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (2.0.5)
Requirement already satisfied: pillow>=6.2.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (10.3.0)
Requirement already satisfied: scipy>=1.5.4 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (1.13.1)
Requirement already satisfied: psutil>=5.6.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (5.9.0)
Requirement already satisfied: astunparse>=1.6.3 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from mindspore->mindnlp) (1.6.3)
Collecting pygtrie<3.0,>=2.1 (from pyctcdecode->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/ec/cd/bd196b2cf014afb1009de8b0f05ecd54011d881944e62763f3c1b1e8ef37/pygtrie-2.5.0-py3-none-any.whl (25 kB)
Collecting hypothesis<7,>=6.14 (from pyctcdecode->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/79/21/83a2c2472c75e6443cd2068a51bf5b9d07505702adae5fb887fe0318629e/hypothesis-6.108.1-py3-none-any.whl (465 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 465.0/465.0 kB 14.0 MB/s eta 0:00:00
Requirement already satisfied: six in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from asttokens>=2.0.4->mindspore->mindnlp) (1.16.0)
Requirement already satisfied: wheel<1.0,>=0.23.0 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from astunparse>=1.6.3->mindspore->mindnlp) (0.43.0)
Collecting aiosignal>=1.1.2 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl (7.6 kB)
Collecting frozenlist>=1.1.1 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/57/15/172af60c7e150a1d88ecc832f2590721166ae41eab582172fe1e9844eab4/frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (239 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 239.4/239.4 kB 14.4 MB/s eta 0:00:00
Collecting multidict<7.0,>=4.5 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/d0/10/2ff646c471e84af25fe8111985ffb8ec85a3f6e1ade8643bfcfcc0f4d2b1/multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (125 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 125.9/125.9 kB 14.1 MB/s eta 0:00:00
Collecting yarl<2.0,>=1.0 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/c6/d6/5b30ae1d8a13104ee2ceb649f28f2db5ad42afbd5697fd0fc61528bb112c/yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl (300 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 300.9/300.9 kB 14.1 MB/s eta 0:00:00
Collecting async-timeout<5.0,>=4.0 (from aiohttp->datasets->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl (5.7 kB)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from huggingface-hub>=0.21.2->datasets->mindnlp) (4.11.0)
Collecting sortedcontainers<3.0.0,>=2.1.0 (from hypothesis<7,>=6.14->pyctcdecode->mindnlp)
Downloading https://pypi.tuna.tsinghua.edu.cn/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl (29 kB)
Requirement already satisfied: python-dateutil>=2.8.2 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2024.1)
Requirement already satisfied: tzdata>=2022.7 in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (from pandas->datasets->mindnlp) (2024.1)
Building wheels for collected packages: jieba
Building wheel for jieba (setup.py) ... done
Created wheel for jieba: filename=jieba-0.42.1-py3-none-any.whl size=19314459 sha256=c13a2538f852cf7ab019a1996950690b65a203218e95481705edd590f7135a5e
Stored in directory: /home/nginx/.cache/pip/wheels/1a/76/68/b6d79c4db704bb18d54f6a73ab551185f4711f9730c0c15d97
Successfully built jieba
Installing collected packages: sortedcontainers, sentencepiece, pygtrie, jieba, addict, xxhash, safetensors, regex, pytest, pyarrow-hotfix, pyarrow, multiprocess, multidict, ml-dtypes, hypothesis, fsspec, frozenlist, async-timeout, yarl, pyctcdecode, aiosignal, tokenizers, aiohttp, datasets, evaluate, mindnlp
Attempting uninstall: pytest
Found existing installation: pytest 8.0.0
Uninstalling pytest-8.0.0:
Successfully uninstalled pytest-8.0.0
Attempting uninstall: fsspec
Found existing installation: fsspec 2024.6.0
Uninstalling fsspec-2024.6.0:
Successfully uninstalled fsspec-2024.6.0
Successfully installed addict-2.4.0 aiohttp-3.9.5 aiosignal-1.3.1 async-timeout-4.0.3 datasets-2.20.0 evaluate-0.4.2 frozenlist-1.4.1 fsspec-2024.5.0 hypothesis-6.108.1 jieba-0.42.1 mindnlp-0.3.1 ml-dtypes-0.4.0 multidict-6.0.5 multiprocess-0.70.16 pyarrow-16.1.0 pyarrow-hotfix-0.6 pyctcdecode-0.5.0 pygtrie-2.5.0 pytest-7.2.0 regex-2024.5.15 safetensors-0.4.3 sentencepiece-0.2.0 sortedcontainers-2.4.0 tokenizers-0.19.1 xxhash-3.4.1 yarl-1.9.4
[notice] A new release of pip is available: 24.1 -> 24.1.2
[notice] To update, run: python -m pip install --upgrade pip
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: jieba in /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages (0.42.1)
[notice] A new release of pip is available: 24.1 -> 24.1.2
[notice] To update, run: python -m pip install --upgrade pip
env: HF_ENDPOINT=https://hf-mirror.com
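Since %%capture swallowed the cell's own stdout, the captured object can be inspected from a later cell. A minimal sketch using IPython's CapturedIO API, plus a check that %env took effect:
print(captured_output.stdout[:300])   # first characters of the captured install log
captured_output.show()                # replay the full captured output
import os
print(os.environ.get('HF_ENDPOINT'))  # -> https://hf-mirror.com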
Data Preparation
The data-preparation stage covers the preprocessing needed before training and evaluation: loading, cleaning, converting, and splitting the data so that the model receives well-structured, high-quality input. This lays the groundwork for accurate and efficient training and evaluation later on.
Code:
import os
import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn
from mindnlp.dataset import load_dataset
from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
imdb_train.get_dataset_size()
Analysis: this cell pulls in the libraries needed for data handling and model training.
First, it imports os, mindspore, and the data-processing and training modules (GeneratorDataset, transforms, the legacy Trainer and Evaluator with their callbacks, and the Accuracy metric).
Next, load_dataset loads the IMDB dataset, split into train and test parts.
Finally, get_dataset_size returns the number of examples in imdb_train.
Output:
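Before any preprocessing, one raw example can be inspected as a sanity check. A small sketch, assuming the IMDB columns 'text' and 'label' (the same column names the map calls below rely on):
sample = next(imdb_train.create_dict_iterator(output_numpy=True))
print(sample['label'])            # 0 = negative, 1 = positive
print(str(sample['text'])[:200])  # beginning of the raw review text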
Processing the Dataset
A function named process_dataset is defined to preprocess a given dataset: it tokenizes the text, casts the label type, and batches the data, preparing the dataset for model training and related tasks.
Code:
import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    def tokenize(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # map dataset
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset
Analysis: numpy is imported first.
The function takes a dataset, a tokenizer, and optional parameters: the maximum sequence length max_seq_len, the batch size batch_size, and a shuffle flag.
Whether padding happens during tokenization depends on the device target: on Ascend, every sequence is padded to max_seq_len so that shapes stay static, while on other backends padding is deferred to batching.
If shuffle is True, the dataset is shuffled with a buffer size of batch_size.
The text column is then mapped through tokenize to produce the input_ids and attention_mask columns, and the label column is cast to int32 and renamed labels.
Finally, batching again depends on the backend: plain batch on Ascend (sequences already share one length), and padded_batch elsewhere, which pads input_ids with the tokenizer's pad token id and attention_mask with 0, as the toy example below illustrates.
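To make the padded_batch semantics concrete, here is a self-contained toy sketch (the names gen and toy are illustrative only):
import numpy as np
from mindspore.dataset import GeneratorDataset

# Toy data: three variable-length "token id" sequences of lengths 2, 3 and 4.
def gen():
    for n in (2, 3, 4):
        yield (np.arange(1, n + 1, dtype=np.int32),)

toy = GeneratorDataset(gen, column_names=["input_ids"])
# (None, 0) pads each batch to its longest sequence, using 0 as the pad value.
toy = toy.padded_batch(2, pad_info={"input_ids": (None, 0)})
for (batch,) in toy.create_tuple_iterator(output_numpy=True):
    print(batch)
# First batch: [[1 2 0] [1 2 3]] -- the length-2 row is padded to length 3.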
GPTTokenizer
Set up the tokenizer (GPTTokenizer) used to convert raw text into token ids.
Code:
from mindnlp.transformers import GPTTokenizer
# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')
# add special tokens: <bos>, <eos>, <pad>
special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)
Analysis: first, the GPTTokenizer class is imported from the mindnlp.transformers module.
Then, from_pretrained initializes a GPTTokenizer object, gpt_tokenizer, from the pretrained 'openai-gpt' vocabulary.
Next, a dictionary named special_tokens_dict defines three special tokens: bos_token (beginning of sequence), eos_token (end of sequence), and pad_token (padding), with their string values.
Finally, add_special_tokens registers these tokens with gpt_tokenizer and returns the number of tokens actually added, stored in num_added_toks.
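A short sketch of the tokenizer in action after the special tokens are added; it assumes the Hugging Face-style call convention that mindnlp tokenizers follow (the sample sentence is illustrative):
sample = gpt_tokenizer("this movie was great", padding='max_length', truncation=True, max_length=8)
print(sample['input_ids'])       # ids padded to length 8 with the <pad> id
print(sample['attention_mask'])  # 1 for real tokens, 0 for padding
print(gpt_tokenizer.pad_token_id, num_added_toks)  # pad id, and the 3 tokens just added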
Dataset Splitting and Processing
Code:
# Split the `imdb_train` dataset 0.7/0.3 into two datasets, named `imdb_train` and `imdb_val`
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
# Process the new `imdb_train` with `process_dataset`, with shuffling enabled (`shuffle=True`), storing the result in `dataset_train`
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
# Process `imdb_val` with `process_dataset`, without shuffling, storing the result in `dataset_val`
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
# Process `imdb_test` with `process_dataset`, storing the result in `dataset_test`
dataset_test = process_dataset(imdb_test, gpt_tokenizer)
# Fetch the next element from a tuple iterator over `dataset_train`
next(dataset_train.create_tuple_iterator())
Analysis: if imdb_train originally held 100 examples, the 0.7/0.3 split would leave about 70 in imdb_train and 30 in imdb_val. process_dataset then turns the three splits into dataset_train, dataset_val, and dataset_test, ready for later use. Finally, next(dataset_train.create_tuple_iterator()) fetches the first batch from dataset_train, a quick way to confirm the pipeline yields tensors of the expected shape; a size check is sketched below.
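As referenced above, a quick size check (a sketch) confirms the split proportions:
print(imdb_train.get_dataset_size())  # about 70% of the original training examples
print(imdb_val.get_dataset_size())    # about 30%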
Training and Evaluation
This section fine-tunes a GPT model for sequence classification and then evaluates it. Training runs with the configured hyperparameters (learning rate, number of epochs, and so on), while callbacks save checkpoints and the best model along the way; evaluation then measures the trained model's performance on the test set.
Code:
from mindnlp.transformers import GPTForSequenceClassification

# set GPT config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
# enlarge the embedding table to cover the 3 special tokens added earlier
model.resize_token_embeddings(model.config.vocab_size + 3)

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
metric = Accuracy()
# define callbacks to save checkpoints and the best model
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],
                  jit=False)
trainer.run(tgt_columns="labels")

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")
Analysis: first, the GPTForSequenceClassification class is imported from mindnlp.transformers, and the model is loaded from the pretrained 'openai-gpt' weights with num_labels=2 for binary classification. The model's pad token id is set to the tokenizer's, and the embedding table is resized by 3 to cover the added special tokens. An nn.Adam optimizer (learning rate 2e-5) and an Accuracy metric are created. The Trainer is configured with the model, the training dataset, the validation dataset, the metric, the epoch count, the optimizer, the callbacks, and jit=False (no just-in-time compilation); trainer.run(tgt_columns="labels") starts fine-tuning.
Finally, an Evaluator is created over dataset_test with the same metric, and evaluator.run(tgt_columns="labels") reports the test accuracy.
Output:
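Once training and evaluation finish, the fine-tuned model can classify a new review. A sketch under two assumptions: mindnlp mirrors the Hugging Face output convention (with a tuple fallback just in case), and IMDB labels follow 0 = negative, 1 = positive; the sample review is illustrative:
review = "I loved this film!"
enc = gpt_tokenizer(review, padding='max_length', truncation=True, max_length=512)
input_ids = mindspore.Tensor([enc['input_ids']], mindspore.int32)
attention_mask = mindspore.Tensor([enc['attention_mask']], mindspore.int32)

model.set_train(False)                 # switch to inference mode
outputs = model(input_ids, attention_mask=attention_mask)
logits = outputs.logits if hasattr(outputs, 'logits') else outputs[0]
print(logits.argmax(-1))               # assumed mapping: 0 = negative, 1 = positive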