Transformers实战——Datasets板块

文章目录

  • 一、基本使用
    • 1.加载在线数据集
    • 2.加载数据集合集中的某一项任务
    • 3.按照数据集划分进行加载
    • 4.查看数据集
      • 查看一条数据集
      • 查看多条数据集
      • 查看数据集里面的某个字段
      • 查看所有的列
      • 查看所有特征
    • 5.数据集划分
    • 6.数据选取与过滤
    • 7.数据映射
    • 8.保存与加载
  • 二、加载本地数据集
    • 1.直接加载文件作为数据集
    • 2.加载文件夹内全部文件作为数据集
    • 3.通过预先加载的其他格式转换加载数据集
    • 4.Dataset with DataCollator

!pip install datasets
from datasets import load_dataset

一、基本使用

1.加载在线数据集

datasets = load_dataset("madao33/new-title-chinese")
datasets
'''
DatasetDict({
    train: Dataset({
        features: ['title', 'content'],
        num_rows: 5850
    })
    validation: Dataset({
        features: ['title', 'content'],
        num_rows: 1679
    })
})
'''

2.加载数据集合集中的某一项任务

boolq_dataset = load_dataset("super_glue", "boolq")
boolq_dataset
'''
DatasetDict({
    train: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 9427
    })
    validation: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3270
    })
    test: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3245
    })
})
'''

3.按照数据集划分进行加载

dataset = load_dataset("madao33/new-title-chinese", split="train")
dataset
'''
Dataset({
    features: ['title', 'content'],
    num_rows: 5850
})
'''
  • 可以取切片
dataset = load_dataset("madao33/new-title-chinese", split="train[10:100]")
dataset
'''
Dataset({
    features: ['title', 'content'],
    num_rows: 90
})
'''
  • 可以取百分比
dataset = load_dataset("madao33/new-title-chinese", split="train[:50%]")
dataset
'''
Dataset({
    features: ['title', 'content'],
    num_rows: 2925
})
'''

  • 可以取多个
dataset = load_dataset("madao33/new-title-chinese", split=["train[:50%]", "train[50%:]"])
dataset
'''
[Dataset({
     features: ['title', 'content'],
     num_rows: 2925
 }),
 Dataset({
     features: ['title', 'content'],
     num_rows: 2925
 })]
'''
boolq_dataset = load_dataset("super_glue", "boolq", split=["train[:50%]", "train[50%:]"])
boolq_dataset
'''
[Dataset({
     features: ['question', 'passage', 'idx', 'label'],
     num_rows: 4714
 }),
 Dataset({
     features: ['question', 'passage', 'idx', 'label'],
     num_rows: 4713
 })]
'''

4.查看数据集

datasets = load_dataset("madao33/new-title-chinese")
datasets
'''
DatasetDict({
    train: Dataset({
        features: ['title', 'content'],
        num_rows: 5850
    })
    validation: Dataset({
        features: ['title', 'content'],
        num_rows: 1679
    })
})
'''

查看一条数据集

datasets["train"][0]

查看多条数据集

datasets["train"][:2]
'''

查看数据集里面的某个字段

datasets["train"]["title"][:5]

# 或者这样也可以
datasets["train"][:5]['title']

查看所有的列

datasets["train"].column_names
'''
['title', 'content']
'''

查看所有特征

datasets["train"].features
'''
{'title': Value(dtype='string', id=None),
 'content': Value(dtype='string', id=None)}
'''

5.数据集划分

dataset = datasets["train"]
dataset.train_test_split(test_size=0.1, seed=3407)
'''
DatasetDict({
    train: Dataset({
        features: ['title', 'content'],
        num_rows: 5265
    })
    test: Dataset({
        features: ['title', 'content'],
        num_rows: 585
    })
})
'''
  • 分类数据集可以按照比例划分(分布均衡),即单看某一个类别所占的比例 train和test中应该是一样的 比如0类在train中占0.3,那test中0类占比也是 0.3
dataset = boolq_dataset["train"]
dataset.train_test_split(test_size=0.1, stratify_by_column="label") 
'''
DatasetDict({
    train: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 8484
    })
    test: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 943
    })
})
'''

# 解释分布均衡
dataset['train']['label'].count(1) / len(dataset['train'])
'''
0.6230551626591231
'''
dataset['test']['label'].count(1) / len(dataset['test'])
'''
0.6230551626591231
'''

6.数据选取与过滤

# 选取
datasets["train"].select([0, 1])
'''
Dataset({
    features: ['title', 'content'],
    num_rows: 2
})
'''
# 过滤
filter_dataset = datasets["train"].filter(lambda example: "中国" in example["title"])
filter_dataset["title"][:5]

7.数据映射

def add_prefix(example):
    example["title"] = 'Prefix: ' + example["title"]
    return example
prefix_dataset = datasets.map(add_prefix)
prefix_dataset["train"][:10]["title"]

from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

def preprocess_function(example):
    model_inputs = tokenizer(example["content"], max_length=512, truncation=True)
    labels = tokenizer(example["title"], max_length=32, truncation=True)
    # label就是title编码的结果
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
processed_datasets = datasets.map(preprocess_function)
processed_datasets
'''
DatasetDict({
    train: Dataset({
        features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5850
    })
    validation: Dataset({
        features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1679
    })
})
'''
  • 使用多进程
    • 注意需要多一个参数 tokenizer=tokenizer
def preprocess_function(example, tokenizer=tokenizer):
    model_inputs = tokenizer(example["content"], max_length=512, truncation=True)
    labels = tokenizer(example["title"], max_length=32, truncation=True)
    # label就是title编码的结果
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

processed_datasets = datasets.map(preprocess_function, num_proc=4)
processed_datasets
'''
DatasetDict({
    train: Dataset({
        features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5850
    })
    validation: Dataset({
        features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1679
    })
})
'''
  • TokenizerFast 可以使用 batched=True加速映射过程
processed_datasets = datasets.map(preprocess_function, batched=True)
processed_datasets
'''
DatasetDict({
    train: Dataset({
        features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5850
    })
    validation: Dataset({
        features: ['title', 'content', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1679
    })
})
'''

  • 删除多余列
processed_datasets = datasets.map(preprocess_function, batched=True, 
                                  remove_columns=datasets["train"].column_names)
processed_datasets
'''
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5850
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1679
    })
})
'''

8.保存与加载

processed_datasets.save_to_disk("./processed_data")

image.png


from datasets import load_from_disk


processed_datasets = load_from_disk("./processed_data")
processed_datasets
'''
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 5850
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1679
    })
})
'''

二、加载本地数据集

1.直接加载文件作为数据集

  • 这里加 split="train"是因为加载本地数据集会默认将数据集视为 train
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 7766
})
'''
  • 也可以用 Dataset
from datasets import Dataset

dataset = Dataset.from_csv("./ChnSentiCorp_htl_all.csv")
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 7766
})
'''

2.加载文件夹内全部文件作为数据集

image.png

dataset = load_dataset("csv", data_dir="/content/all_data", split='train')
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 15532
})
'''
  • 指定加载文件夹中哪些文件
dataset = load_dataset("csv",
                       data_files=['/content/all_data/ChnSentiCorp_htl_all.csv',
                                   '/content/all_data/ChnSentiCorp_htl_all2.csv'], 
                       split='train')
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 15532
})
'''

  • cache_dir:构建的数据集缓存目录,方便下次快速加载
dataset = load_dataset("csv", 
                       data_files=['/content/all_data/ChnSentiCorp_htl_all.csv',
                                  '/content/all_data/ChnSentiCorp_htl_all2.csv'], 
                       split='train',
                       cache_dir='dir')
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 15532
})
'''

image.png


3.通过预先加载的其他格式转换加载数据集

import pandas as pd

data = pd.read_csv("./ChnSentiCorp_htl_all.csv")
data.head()
dataset = Dataset.from_pandas(data)
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 7766
})
'''
  • List格式的数据需要内嵌{},明确数据字段
# List格式的数据需要内嵌{},明确数据字段
data = [{"text": "abc"}, {"text": "def"}]
# data = ["abc", "def"] # 报错
Dataset.from_list(data)
'''
Dataset({
    features: ['text'],
    num_rows: 2
})
'''

4.Dataset with DataCollator

from transformers import  DataCollatorWithPadding
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split='train')
dataset = dataset.filter(lambda x: x["review"] is not None)
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 7765
})
'''
def process_function(examples):
    tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples
tokenized_dataset = dataset.map(process_function, 
                                batched=True, 
                                remove_columns=dataset.column_names)
tokenized_dataset
'''
Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 7765
})
'''
print(tokenized_dataset[:3])
'''
{'input_ids': [[101, 6655, 4895, 2335, 3763, 1062, 6662, 6772, 6818, 117, 852, 3221, 1062, 769, 2900, 4850, 679, 2190, 117, 1963, 3362, 3221, 107, 5918, 7355, 5296, 107, 4638, 6413, 117, 833, 7478, 2382, 7937, 4172, 119, 2456, 6379, 4500, 1166, 4638, 6662, 5296, 119, 2791, 7313, 6772, 711, 5042, 1296, 119, 102], [101, 1555, 1218, 1920, 2414, 2791, 8024, 2791, 7313, 2523, 1920, 8024, 2414, 3300, 100, 2160, 8024, 3146, 860, 2697, 6230, 5307, 3845, 2141, 2669, 679, 7231, 106, 102], [101, 3193, 7623, 1922, 2345, 8024, 3187, 6389, 1343, 1914, 2208, 782, 8024, 6929, 6804, 738, 679, 1217, 7608, 1501, 4638, 511, 6983, 2421, 2418, 6421, 7028, 6228, 671, 678, 6821, 702, 7309, 7579, 749, 511, 2791, 7313, 3315, 6716, 2523, 1962, 511, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 
'labels': [1, 1, 1]}
'''
collator = DataCollatorWithPadding(tokenizer=tokenizer)
  • 动态填充,每个 batch_size长度不一样
from torch.utils.data import DataLoader

dl = DataLoader(tokenized_dataset, batch_size=4, collate_fn=collator, shuffle=True)
num = 0
for batch in dl:
    print(batch["input_ids"].size())
    num += 1
    if num > 10:
        break
'''
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 115])
torch.Size([4, 128])
torch.Size([4, 117])
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 128])
torch.Size([4, 127])
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/159035.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MARKDOWN

新的改变 我们对Markdown编辑器进行了一些功能拓展与语法支持,除了标准的Markdown编辑器功能,我们增加了如下几点新功能,帮助你用它写博客: 全新的界面设计 ,将会带来全新的写作体验;在创作中心设置你喜爱…

vscode 配置 lua

https://luabinaries.sourceforge.net/ 官网链接 主要分为4个步骤 下载压缩包,然后解压配置系统环境变量配置vscode的插件测试 这里你可以选择用户变量或者系统环境变量都行。 不推荐空格的原因是 再配置插件的时候含空格的路径 会出错,原因是空格会断…

YOLOv5 配置C2模块构造新模型

🍨 本文为[🔗365天深度学习训练营学习记录博客 🍦 参考文章:365天深度学习训练营 🍖 原作者:[K同学啊] 🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb4…

html使用天地图写一个地图列表

一、效果图&#xff1a; 点击左侧地址列表&#xff0c;右侧地图跟着改变。 二、代码实现&#xff1a; 一进入页面时&#xff0c;通过body调用onLoad"onLoad()"函数&#xff0c;确保地图正常显示。 <body onLoad"onLoad()"><!--左侧代码-->…

电磁场与电磁波part2--电磁场的基本规律

1、电流连续性方程的微分形式 表明时变电流场是有散场&#xff0c;电流线是由电荷随时间变化的地方发出或终止的&#xff0c;在正电荷随时间减小的地方就会发出电流线&#xff0c;在正电荷随时间增加的地方就会终止电流线。 2、任何一个标量函数的梯度再求旋度时恒等于零&#…

【uniapp】华为APP真机运行(novas系列)

依华为手机为例&#xff0c;首先数据线连接电脑&#xff0c;然后在手机上做如下操作&#xff1a; 1&#xff09;打开设置 2&#xff09;设置——关于手机 3&#xff09;连续点击软件版本号&#xff0c;此时手机处于开发者模式 4) 回到设置——系统和更新 5&#xff09;点击开…

全球温度数据下载

1.全球年平均温度下载https://www.ncei.noaa.gov/data/global-summary-of-the-year/archive/ 2.全球月平均气温下载https://www.ncei.noaa.gov/data/global-summary-of-the-month/archive/ 3.全球日平均气温下载https://www.ncei.noaa.gov/data/global-summary-of-the-day/ar…

使用Sqoop命令从Oracle同步数据到Hive,修复数据乱码 %0A的问题

一、创建一张Hive测试表 create table test_oracle_hive(id_code string,phone_code string,status string,create_time string ) partitioned by(partition_date string) ROW FORMAT DELIMITED FIELDS TERMINATED BY ,; 创建分区字段partition_date&#xff0c…

【Effective C++ 笔记】(四)设计与声明

【四】设计与声明 条款18 &#xff1a; 让接口容易被正确使用&#xff0c;不易被误用 Item 18: 让接口容易被正确使用&#xff0c;不易被误用 Make interfaces easy to use correctly and hard to use incorrectly. “让接口容易被正确使用&#xff0c;不易被误用”&#xff0…

C语言日记——调试篇

一、调试调试的基本步骤 发现程序错误的存在 以隔离、消除等方式对错误进行定位 确定错误产生的原因 提出纠正错误的解决办法 对程序错误予以改正&#xff0c;重新测试 二、Debug和Release Debug通常称为调试版本&#xff0c;它包含调试信息&#xff0c;并且不作任何优化…

数据结构C语言之线性表

发现更多计算机知识&#xff0c;欢迎访问Cr不是铬的个人网站 1.1线性表的定义 线性表是具有相同特性的数据元素的一个有限序列 对应的逻辑结构图形&#xff1a; 从线性表的定义中可以看出它的特性&#xff1a; &#xff08;1&#xff09;有穷性&#xff1a;一个线性表中的元…

线程状态及线程之间通信

线程状态概述 当线程被创建并启动以后&#xff0c;它既不是一启动就进入了执行状态&#xff0c;也不是一直处于执行状态。在线程的生命周期中&#xff0c; 有几种状态呢&#xff1f;在 java.lang.Thread.State 这个枚举中给出了六种线程状态&#xff1a; 线程状态 导致状态发生…

Objectarx 使用libcurl请求WebApi

因为开发cad需要请求服务器的数据&#xff0c;再次之前我在服务器搭设了webapi用户传递数据&#xff0c;所以安装了libcurl在objectarx中使用数据。 Open VS2012 x64 Native Tools Command Prompt补充地址&#xff1a; 我在此将相关的引用配置图片&#xff0c;cad里面的应用和…

Linux中的进程等待(超详细)

Linux中的进程等待 1. 进程等待必要性2. 进程等待的方法2.1 wait方法2.2 waitpid方法 3. 获取子进程status4. 具体代码实现 1. 进程等待必要性 我们知道&#xff0c;子进程退出&#xff0c;父进程如果不管不顾&#xff0c;就可能造成‘僵尸进程’的问题&#xff0c;进而造成内…

UE的PlayerController方法Convert Mouse Location To World Space

先上图&#xff1a; Convert Mouse Location To World这是PlayerController对象中很重要的方法。 需要说明的是两个输出值。 第一个是World Location&#xff0c;这是个基于世界空间的位置值&#xff0c;一开始我以为这个值和当前摄像机的位置是重叠的&#xff0c;但是打印出来…

kaggle项目部署

目录 流程详细步骤注意事项 流程 修改模块地址打包项目上传到kaggle Datasets创建code文件&#xff0c;导入数据与项目粘贴train.py文件&#xff0c;调整超参数&#xff0c;选择GPUsave version&#xff0c;后台训练查看训练结果 详细步骤 打开kaggle网站&#xff0c;点击da…

号卡分销管理系统搭建

随着移动互联网的发展&#xff0c;各种手机应用层出不穷&#xff0c;其中包括了很多用于企业管理的软件。号卡系统分销管理软件就是其中的一种。它是一种基于移动互联网的企业管理软件&#xff0c;能够帮助企业进行号卡的分销管理&#xff0c;从而提高企业的效率和竞争力。 …

抖音快手判断性别、年龄自动关注脚本,按键精灵开源代码!

这个是支持抖音和快手两个平台的&#xff0c;可以进入对方主页然后判断对方年龄和性别&#xff0c;符合条件的关注&#xff0c;不符合条件的跳过下一个ID&#xff0c;所以比较精准&#xff0c;当然你可以二次开发加入更多的平台&#xff0c;小红书之类的&#xff0c;仅供学习&a…

Linux(3):Linux 的文件权限与目录配置

把具有相同的账户放入到一个组里面&#xff0c;这个组就是这两个账户的 群组 。在访问资源&#xff08;操作系统中计算机的资源&#xff09;时&#xff0c;可以让这个组里面的所有用户都具有访问权限。 每个账号都可以有多个群组的支持。 在我们Liux 系统当中&#xff0c;默认的…

kibana8.10.4简单使用

1.创建discovery里的日志项目 点击stack management 选择kibana里的数据视图&#xff0c;右上角创建数据视图&#xff0c;输入名称。索引范围。例子 example-* ,匹配以example-开头的所有index。 然后点击 保存数据视图到kibana&#xff0c; 2.Kibana多用户创建及角色权限控…