文章目录
- 1. 迭代器与可迭代对象(Iterable)
- 1.1 可迭代对象(Iterable)
- 1.2 迭代器( Iterator)
- 2. 自定义一个可迭代器
- 2.1 实现迭代器
- 2.2 for 遍历迭代器的过程
- 3. yolov8 Dataset实现案例
Python迭代器的作用是提供一种
遍历数据集合
的方式。它是一个可以被迭代
的对象,可以使用迭代器的方法来逐个访问集合中的元素,而不需要事先知道集合
的大小。
在深度学习的Dataset
和Dataloader
中,就是通过迭代器
实现的,因此迭代器是一个非常重要的概念和工具
迭代器具有以下几个重要的特点:
节省内存
:迭代器一次只返回一个元素,不需要一次性将整个集合加载到内存中
,这样可以节省内存空间,特别是在处理大型数据集合时非常有用。惰性计算
:迭代器在需要时才会计算下一个元素
,而不是一次性计算所有的元素。这种惰性计算的方式可以在处理大量数据时提高效率。可逆迭代
:迭代器可以反向遍历集合,而不需要额外的复制和存储。支持并行处理
:迭代器可以同时遍历多个集合,实现并行处理。
总之,迭代器提供了一种灵活、高效和节省内存的方式来遍历数据集合,是Python中非常重要的概念和工具。
1. 迭代器与可迭代对象(Iterable)
1.1 可迭代对象(Iterable)
表示该对象可迭代
, 它的类中需要定义__iter__
方法,只要是实现了__iter__方法的类就是可迭代对象
。
from collections.abc import Iterable, Iterator
class A(object):
def __init__(self):
self.a = [1, 2, 3]
def __iter__(self):
# 此处返回啥无所谓
return self.a
cls_a = A()
# True
print(isinstance(cls_a, Iterable))
-
可迭代对象,必须具备
__iter__
这个特殊函数,并且返回
一个可迭代对象。可以通过isinstance(cls_a, Iterable)
可以判断是否是可迭代对象 -
如果一个Iterable,仅仅定义了
__iter__
方法,是没有特别大的用途,因为依然无法迭代,实际上 Iterable 仅仅是提供了一种抽象规范接口
1.2 迭代器( Iterator)
迭代器Iterator一定是可迭代对象Iterable
,但反过来,可迭代对象不一定是迭代器
,因为迭代器只是可迭代对象的一种表示形式。- 实现了
__next__
和__iter__
方法的类才能称为迭代器
就可以被 for 循环遍历数据。
因此,通过自定义实现迭代器Iterator
,必须具备__next__
和__iter__
两个方法,如下案例所示:
class A(object):
def __init__(self):
self.index = -1
self.a = [1, 2, 3]
# 必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历
# 因为本类自身实现了 __next__,所以通常都是返回 self 对象即可
def __iter__(self):
return self
def __next__(self):
self.index += 1
if self.index < len(self.a):
return self.a[self.index]
else:
# 抛异常,for 内部会自动捕获,表示迭代完成
raise StopIteration("遍历完了")
cls_a = A()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # True 从这里可以看出来同时具有__iter__和__next__的类,是Iterator
print(isinstance(iter(cls_a), Iterator)) # True 这里加不加iter()都一样,因为这个类里面的iter也是直接返回自身(self)
#另外补充一点这个a和上面类里面的a是不一样的;这里的用i(任意字母都可以)也能
for a in cls_a:
print(a)
# 打印 1 2 3
-可以看到,通过实现__iter__
, 和__next__
这两个特殊方法,实现了迭代器A。
2. 自定义一个可迭代器
2.1 实现迭代器
在Python中从头开始构建迭代器很容易。我们只需要实现
这些方法__iter__()
和__next__()
。
__iter__()
方法需要返回迭代器对象
, 最简单直接返回self
,也可以返回新的可迭代对象。如果需要,可以执行一些初始化
。__next__()
方法必须返回序列中的下一项
。在到达终点时,以及在随后的调用中,它必须引发StopIteration
这里,我们展示了一个示例,通过定义一个迭代器
,手动实现python的range
方法:
class Range:
def __init__(self,start,stop,step):
self.start = start
self.stop = stop
self.step = step
def __iter__(self):
self.value = self.start
return self
def __next__(self):
# 1. 每执行一次next,需要返回一个值
# 2. 如果没有下一个值了,需要通过StopIteration 反馈异常
if self.value < self.stop:
old_value = self.value
self.value = self.value +self.step
return old_value
else:
raise StopIteration()
for i in Range(0,5,1):
print(i)
输出结果
我们知道python 的range方法 有三个参数:start_value, stop_value和step,因此我们也定义这3个参数。
- 首先定义类的
__init__
方法 - 然后实现
__iter__
方法,并返回可迭代对象,这里返回了本身(slef)。其中在__iter__
方法中,初始化了返回值self.value
(__iter__
方法中,如果需要,可以执行一些初始化
) - 最后通过
__next__
方法,定义每一次迭代输出的值。当迭代完了,没有下一个值,通过raise StopIteration()
, 反馈错误。
在for循环中,最后反馈的raise StopIteration()
erro,会被for循环
处理掉,因此我们看不到报错的提醒。
2.2 for 遍历迭代器的过程
通过单步调试,可以观察到如下执行顺序:
- (1) 首先调用
__init__
方法,对成员变量进行初始化 - (2) 紧接着,进入
__iter__
,获得一个迭代器对象self - (3) 最后进入
__next__
方法,执行循环迭代过程
,每迭代一次返回相应的迭代结果。最后迭代会执行raise StopIteration()
语句,此时程序结束
(异常被for处理,所以没有显示出来)
可以看到所谓for循环,本质上是就是一次次的执行__next__
方法的过程。for循环可等价如下代码:
r = Range(0,5,1) # __init__ 构造对象
iteration = iter(r) # 执行__iter__,获得迭代器对象
#iteration = r.__iter__()
next(iteration) # 执行__next__
#iteration.__next__()
next(iteration) # 执行__next__
next(iteration) # 执行__next__
next(iteration) # 执行__next__
next(iteration) # 执行__next__
try:
next(iteration)
except StopIteration as e:
pass
- 最后一次会报
StopIteration
异常,因此通过try except
进行处理。 - 注:`
iter(r)
就是r.__iter__()
实现的next(iteration)
就是通过iteration.__next__()
实现的
3. yolov8 Dataset实现案例
class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
"""Initialize instance variables and check for valid input."""
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
path = Path(path).read_text().rsplit()
files = []
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
p = str(Path(p).resolve())
if '*' in p:
files.extend(sorted(glob.glob(p, recursive=True))) # glob
elif os.path.isdir(p):
files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
elif os.path.isfile(p):
files.append(p) # files
else:
raise FileNotFoundError(f'{p} does not exist')
images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
ni, nv = len(images), len(videos)
self.img_size = img_size
self.stride = stride
self.files = images + videos
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
self.mode = 'image'
self.auto = auto
self.transforms = transforms # optional
self.vid_stride = vid_stride # video frame-rate stride
if any(videos):
self._new_video(videos[0]) # new video
else:
self.cap = None
assert self.nf > 0, f'No images or videos found in {p}. ' \
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
def __iter__(self):
"""Returns an iterator object for iterating over images or videos found in a directory."""
self.count = 0
return self
def __next__(self):
"""Iterator's next item, performs transformation on image and returns path, transformed image, original image, capture and size."""
if self.count == self.nf:
raise StopIteration
path = self.files[self.count]
if self.video_flag[self.count]:
# Read video
self.mode = 'video'
for _ in range(self.vid_stride):
self.cap.grab()
ret_val, im0 = self.cap.retrieve()
while not ret_val:
self.count += 1
self.cap.release()
if self.count == self.nf: # last video
raise StopIteration
path = self.files[self.count]
self._new_video(path)
ret_val, im0 = self.cap.read()
self.frame += 1
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
else:
# Read image
self.count += 1
im0 = cv2.imread(path) # BGR
assert im0 is not None, f'Image Not Found {path}'
s = f'image {self.count}/{self.nf} {path}: '
if self.transforms:
im = self.transforms(im0) # transforms
else:
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
return path, im, im0, self.cap, s
dataset = LoadImages(source, imgsz=imgsz, vid_stride=vid_stride)
dataloader = build_dataloader(dataset, batch_size, workers, shuffle, rank)
- 可以看到Dataset,也是通过一个迭代器实现,这样做的
好处
就是:需要时才会返回处理好的数据
,而不需要一次性将整个集合加载到内存中
,这样可以节省内存空间,也提高了数据处理的效率。