一、PyTorch加载数据初认识
Dataset:提供一种方式去获取数据及其label
如何获取每一个数据及其label
总共有多少的数据
Dataloader:为后面的网络提供不同的数据形式
数据集
在编译器中导入Dataset
from torch.utils.data import Dataset
可以在jupyter中查看Dataset官方文档:
help(Dataset)
或者
Dataset??
二、Dataset类代码实战
将数据集复制到项目中,命名为dataset,右键拷贝路径。
在pycharm中的控制台运行:
(注意:粘贴完拷贝的路径后需要加上""表示转义字符,共有两个斜杠,否则会报错)
输入img.show()会展示出图片
获取每个图片的地址,创建图片地址列表:
(获得了文件夹的地址后。将文件夹里的数据〔所有照片的路径地址)存入列表里)
可以换成拼接图片路径:
import os
root_dir = "learn_pytorch/dataset/train"
label_dir = "ants"
path = os.path.join(root_dir, label_dir)
测试第一张图片
path = os.path.join(root_dir, label_dir)
img_path = os.listdir(path) # 所有图片地址列表
idx = 0
img_name = img_path[idx] # 第一张图片
img_item_path = os.path.join(root_dir, label_dir, img_name) # 第一张图片地址
read_data.py
from torch.utils.data import Dataset
# import cv2
from PIL import Image
import os # 获取所有图片地址
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "learn_pytorch/dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
将上述代码输入到控制台,然后进行测试:
三、Tensorboard的使用
在编译器中导入
from torch.utils.tensorboard import SummaryWriter
SummaryWriter类使用
在pycharm中查看说明文档方法:可以直接按住ctrl键,点击类名
创建实例对象:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
writer.add_image()
writer.add_scalar()
writer.close()
add_scalar()方法的使用
pycharm中ctrl+‘/’可以注释,注释掉writer.add_image()
add_scalar()方法:
测试:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
# writer.add_image()
# y=x
for i in range(100):
writer.add_scalar("y=x", i, i)
writer.close()
报错。没有安装Tensorboard
安装Tensorboard
在pycharm的Terminal中运行或在anaconda命令行中激活pytorch环境运行
pip install tensorboard
再次测试:
运行后生成了logs文件夹,里面是执行过的事件文件
打开事件文件
logdir=事件文件所在文件夹名
在Terminal中运行,点击链接即可:
tensorboard --logdir=logs
上面是默认的端口,还可以指定端口:
tensorboard --logdir=logs --port=6007
add_image()方法的使用
add_scalar()方法:
image的类型:
在pycharm工作台获取图片路径
image_path = "learn_pytorch/dataset/train/ants/0013035.jpg"
测试:
from PIL import Image
img = Image.open(image_path)
print(type(img))
PIL.JpegImagePlugin.JpegImageFile类型不满足要求。
利用numpy.array(),对PIL图片进行转换。
(另一种方法:利用Opencv读取图片,获得numpy型图片数据)
import numpy as np
img_array = np.array(img)
print(type(img_array))
从PIL到numpy,需要在add_image()中指定shape中每一个数字/维表示的含义,否则会报错。
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from PIL import Image
writer = SummaryWriter("logs")
image_path = "learn_pytorch/dataset/train/ants/0013035.jpg"
img_PIL = Image.open(image_path)
img_array = np.array(img_PIL)
print(type(img_array))
print(img_array.shape)
writer.add_image("test", img_array, 1, dataformats='HWC')
# y=2x
for i in range(100):
writer.add_scalar("y=2x", 3*i, i)
writer.close()
运行结果:
点开tensorboard会显示出图片:
更改图片地址,换一张图片,并改成第二步:
运行后tensorboard中的图片变成俩个图片滑动变换。
更改tag,运行后重新生成了一个单张图片: