文章目录
- 前言
- 一、数据主函数源码解读
- 1、图像函数源码调用解读
- 2、文本函数源码调用解读
- 3、tokenizer生成函数
- 4、llama2_text_processor文本处理函数解读
- 二、create_dataset_function函数源码代码解读
- 三、sat库之make_loaders函数源码解读
- 1、make_loaders函数调用说明
- 2、make_loaders函数源码解读
- 第一部分源码解读
- 第二部分源码解读
- 第三部分源码解读
- make_loaders源码展示
- 四、sat库之make_dataset_full函数源码解读
- 1、参数配置
- 2、数据格式说明
- 3、make_dataset_full数据读取源码解读
- 4、make_dataset_full源码展示
- 五、sat库之make_data_loader函数源码解读
- 六、ItemDataset数据类源码解读
- 1、图像数据处理
- 1、process_img(self, img)
- 2、self.image_processor(img)
- 2、文本数据处理
- 1、原始图像验证码文本数据处理
- 2、使用自己文本数据代码修改
前言
本文是CogVLM是一个多模态大型模型,它能够处理文本、图像和其他类型的数据。在数据处理方面,CogVLM可以接收多种类型的输入数据,包括文本、图像、音频等。然而,很少有人对代码数据处理进行解读或者基本找不到。基于此,本文将结合源码给出CogVLM大模型数据处理内容,主要包含图像数据处理、文本tokenizer构建、文本加工与修改自己文本方法代码修改。总之,我将结合代码一步一步带领读者实现大模型数据处理源码内容。
一、数据主函数源码解读
CogVLM的数据处理包含2部分,一部分是图像数据处理,另一部分是文本数据处理。其源码位于finetune_cogvlm_demo.py文件,如下:
from utils.utils import llama2_tokenizer
tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
image_processor = get_image_processor(args.eva_args["image_size"][0]) # 获得图像加工函数,并附image_size参数
text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length) # 获得文本加工函数
model = training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor), collate_fn=data_collator, forward_step_eval=forward_step_eval)
# 训练函数 参数 模型 训练机制 处理数据函数-->给该函数添加部分参数<---参数为函数 dataloader整合数据 评估函数
从上面可看出获得图像处理函数调用是通过get_image_processor,而获得文本处理函数是调用llama2_text_processor函数,然后通过training_main函数把图像处理与文本处理函数分别作为image_processor与text_processor参数传递。我将说明这2个函数如何处理成参数方法。
1、图像函数源码调用解读
图像处理实际是调用blip2_image_processor_func_with_inputs函数(后面我会详细解读),实现图像加工,最终作为image_processor函数参数。源码如下:
def get_image_processor(image_size):
return partial(blip2_image_processor_func_with_inputs, BlipImageEvalProcessor(image_size))
2、文本函数源码调用解读
文本处理实际是通过huggingface的方式调用tokenizer等方法,在通过llama2_text_processor类对文本处理,我后面会解读,这里介绍通过这样方式作为text_processor函数参数。源码如下:
from utils.utils import llama2_tokenizer
tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length) # 获得文本加工函数
3、tokenizer生成函数
对文本的tokenizer处理,CogVLM调用huggingface的llama函数,使用from utils.utils import llama2_tokenizer
调用,在源码中使用函数如下:
tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
而对于语言函数
from transformers import LlamaTokenizer
def llama2_tokenizer(tokenizer_path, signal_type="base"):
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = 32000
tokenizer.boi = "[IMG]"
tokenizer.eoi = "[/IMG]"
assert signal_type in ["base", "chat", "vqa", "chat_old"]
tokenizer.signal_type = signal_type
return tokenizer
4、llama2_text_processor文本处理函数解读
该函数是文本处理相关内容,我将其代码罗列如下:
class llama2_text_processor:
def __init__(self, tokenizer, max_target_length=2048, image_length=257, model=None):
self.tokenizer = tokenizer
self.max_target_length = max_target_length
self.image_length = image_length
def __call__(self, caption, prompt=""):
if '<EOI>' not in prompt:
prompt = self.replace_tags_with_empty(prompt)
# caption = self.replace_tags_with_empty(caption)
history = []
prompt = self.history_to_prompt(prompt, history)
input_ids = [self.tokenizer.bos_token_id]
prompt_splits = prompt.split('<EOI>')
caption_splits = caption.split('<EOI>')
if len(prompt_splits) > 0:
input_ids.extend(self.tokenizer.encode(prompt_splits[0], add_special_tokens=False))
for tokens in prompt_splits[1:]:
tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False)
input_ids.extend(tokens_with_img)
context_length = len(input_ids) + (len(prompt_splits)-1) * (self.image_length + 1)
if context_length > self.max_target_length - 10:
return None
if len(caption_splits) > 0:
input_ids.extend(self.tokenizer.encode(caption_splits[0], add_special_tokens=False))
for tokens in caption_splits[1:]:
tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False)
input_ids.extend(tokens_with_img)
if len(input_ids) > self.max_target_length - self.image_length - 5:
input_ids = input_ids[:self.max_target_length - self.image_length - 5]
input_ids += [self.tokenizer.eos_token_id]
while -100 in input_ids:
img_idx = input_ids.index(-100)
input_ids = input_ids[:img_idx] + [0] * (self.image_length + 1) + [-1] + input_ids[img_idx+1:]
image_position = []
while -1 in input_ids:
img_idx = input_ids.index(-1)
input_ids[img_idx] = 0
image_position.append(img_idx)
image_embed_mask = [0] * len(input_ids)
vision_expert_mask = [0] * len(input_ids)
image_rope_mask = [0] * len(input_ids)
for idx in image_position:
image_embed_mask[idx-self.image_length-1: idx+1] = [1] * (self.image_length + 2)
vision_expert_mask[idx-self.image_length-1: idx] = [1] * (self.image_length + 1)
image_rope_mask[idx - self.image_length: idx] = [1] * self.image_length
attention_mask = [1] * len(input_ids)
labels = [-100] * context_length + input_ids[context_length:]
pad_len = self.max_target_length - len(input_ids)
input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
attention_mask = attention_mask + [1] * pad_len
vision_expert_mask = vision_expert_mask + [0] * pad_len
image_embed_mask = image_embed_mask + [0] * pad_len
image_rope_mask = image_rope_mask + [0] * pad_len
np_mask = np.tril(np.expand_dims(np.array(attention_mask), 0).repeat(len(attention_mask), 0))
labels = labels + [-100] * pad_len
for idx in image_position:
labels[idx-self.image_length-1: idx+1] = [-100] * (self.image_length + 2)
position_ids = []
pid = -1
for i in range(len(input_ids)):
if image_rope_mask[i] == 0 or (i > 0 and image_rope_mask[i] != image_rope_mask[i - 1]):
pid += 1
position_ids.append(pid)
input_ids = torch.tensor(input_ids).unsqueeze(0)
labels = torch.tensor(labels).unsqueeze(0)
attention_mask = torch.from_numpy(np_mask).unsqueeze(0).unsqueeze(0)
image_embed_mask = torch.tensor(image_embed_mask).unsqueeze(0)
vision_expert_mask = torch.tensor(vision_expert_mask).unsqueeze(0)
image_rope_mask = torch.tensor(image_rope_mask).unsqueeze(0)
position_ids = torch.tensor(position_ids).unsqueeze(0)
context_length = torch.tensor(context_length).unsqueeze(0).long()
return {'input_ids': input_ids, 'labels': labels, 'position_ids': position_ids, 'attention_mask': attention_mask, 'image_embed_mask': image_embed_mask,
'context_length': context_length, 'image_position': image_position, 'vision_expert_mask': vision_expert_mask, 'image_rope_mask': image_rope_mask
}
def history_to_prompt(self, query, history):
return _history_to_prompt[self.tokenizer.signal_type](self, query, history)
def replace_tags_with_empty(self, text):
return re.sub('<pad>|<s>|</s>|<EOI>', '', text)
二、create_dataset_function函数源码代码解读
在使用model = training_main函数create_dataset_function=partial(create_dataset_function, image_processor, text_processor)
,借助python自带partial函数实现数据类处理,其源码如下:
from utils.utils import ItemDataset
def create_dataset_function(image_processor, text_processor, path, args):
dataset = ItemDataset(image_processor, text_processor, args, path)
return dataset
我们可发现该函数调用了ItemDataset类,该类恰好是数据加工的迭代器类,类似taorch的dataset等功能,实际也是继承torch的dataset类,进一步封装加工数据。其中image_processor, text_processor是上面提到图像加工与文本加工方法代码。
三、sat库之make_loaders函数源码解读
在上面,我们已给出training_main
函数,该函数集成数据处理、模型训练等方法,其中hooks保存不同方法,而数据处理代码为make_loaders。
1、make_loaders函数调用说明
make_loaders为模型数据加载方法,又是sat库集成方法,其源码如下:
# Data stuff. 数据处理方法适用
train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'], collate_fn=collate_fn) # 通过该函数调用
很明显,make_loaders参数为需要参数、处理数据类(集成dataset)、collate_fn函数(类似dataloader处理batch数据方式)。
源码位置:sat.data_utils.configure_data.py文件
2、make_loaders函数源码解读
我将maker_loaders源码分为三个部分,第一个部分是使用partial函数将传递数据类create_dataset_function(实际是ItemDataset类)赋参数并重命名函数为make_dataset;第二部分是否传训练、验证、测试路径,使用ItemDataset类处理数据,类似torch的dataset结构;第三部分也是调用sat库的make_data_loader包装对应数据,类似torch的dataloader结构,且调用自己传递的collate_fn方法。
第一部分源码解读
将我们对数据处理create_dataset_function类(实际是ItemDataset类)使用sat库的make_dataset_full函数包装,源码如下:
make_dataset = partial(make_dataset_full, create_dataset_function=create_dataset_function, batch_from_same_dataset=args.batch_from_same_dataset)
第二部分源码解读
根据传递路径参数,使用上面make_dataset方法处理数据,类似torch的dataset,源码如下:
# make datasets splits and tokenizer
train = None
valid = None
test = None
if args.train_data is not None:
train = make_dataset(**data_set_args, args=args, dataset_weights=args.train_data_weights, is_train_data=True)
if should_split(split):
train, valid, test = train
# make training and val dataset if necessary
if valid is None and args.valid_data is not None:
eval_set_args['path'] = args.valid_data
valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
if test is None and args.test_data is not None:
eval_set_args['path'] = args.test_data
test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
第三部分源码解读
将我们处理的datset数据进行dataloader包装,类似torch的dataloader,且调用自己写的collate_fn方法,源码如下:
# wrap datasets with data loader
if train is not None and args.batch_size > 0:
train = make_data_loader(train, batch_size, args, split='train', collate_fn=collate_fn)
args.do_train = True
else:
args.do_train = False
eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
if valid is not None:
valid = make_data_loader(valid, eval_batch_size, args, split='val', collate_fn=collate_fn)
args.do_valid = True
else:
args.do_valid = False
if test is not None:
test = make_data_loader(test, eval_batch_size, args, split='test', collate_fn=collate_fn)
args.do_test = True
else:
args.do_test = False
make_loaders源码展示
另外,该函数也对相应参数做处理,如world_size等内容,所有源码如下:
def make_loaders(args, create_dataset_function, collate_fn=None):
"""makes training/val/test
Args:
args.train_data, args.valid_data, args.test_data: str. Paths to the dataset.
args.split: str. format: "8,1,1". how to split train_data.
args.dataset_type: use to create the right datasets.
"""
make_dataset = partial(make_dataset_full, create_dataset_function=create_dataset_function, batch_from_same_dataset=args.batch_from_same_dataset)
world_size = torch.distributed.get_world_size( group=mpu.get_data_parallel_group())
batch_size = args.batch_size * world_size
eval_batch_size = batch_size
if args.eval_batch_size is not None:
eval_batch_size = args.eval_batch_size * world_size
split = get_split(args)
data_set_args = {
'path': args.train_data,
'split': split,
}
eval_set_args = copy.copy(data_set_args)
eval_set_args['split'] = [1.]
# make datasets splits and tokenizer
train = None
valid = None
test = None
if args.train_data is not None:
train = make_dataset(**data_set_args, args=args, dataset_weights=args.train_data_weights, is_train_data=True)
if should_split(split):
train, valid, test = train
# make training and val dataset if necessary
if valid is None and args.valid_data is not None:
eval_set_args['path'] = args.valid_data
valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
if test is None and args.test_data is not None:
eval_set_args['path'] = args.test_data
test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
# wrap datasets with data loader
if train is not None and args.batch_size > 0:
train = make_data_loader(train, batch_size, args, split='train', collate_fn=collate_fn)
args.do_train = True
else:
args.do_train = False
eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
if valid is not None:
valid = make_data_loader(valid, eval_batch_size, args, split='val', collate_fn=collate_fn)
args.do_valid = True
else:
args.do_valid = False
if test is not None:
test = make_data_loader(test, eval_batch_size, args, split='test', collate_fn=collate_fn)
args.do_test = True
else:
args.do_test = False
return train, valid, test
四、sat库之make_dataset_full函数源码解读
1、参数配置
我是在vscode运行,我对训练数据参数配置作为列子,可配置三种方式,假设CogVLM-SFT-311K文件夹有2个文件,分别为llava_instruction_multi_conversations_formate与llava_instruction_single_conversation_formate文件,可单独给子文件路径、也可给子文件上一个文件路径、也可给多个文件列表路径,具体如下:
# 第一种方式:
"--train-data", "/extend_disk/tj/data/CogVLM-SFT-311K/llava_instruction_multi_conversations_formate",
# 第二种方式:
"--train-data", "/extend_disk/tj/data/CogVLM-SFT-311K",
# 第三种方式:
"--train-data", "/extend_disk/tj/data/CogVLM-SFT-311K/llava_instruction_multi_conversations_formate", "/extend_disk/tj/data/CogVLM-SFT-311K/llava_instruction_single_conversation_formate" ,
2、数据格式说明
单个数据文件夹内容如下:
json文件内容如下:
3、make_dataset_full数据读取源码解读
make_loaders函数中的make_dataset类调用是被make_dataset_full函数包装,主要处理一些逻辑,使其进入数据类为统一格式,其源码如下:
ds = []
for p in path:
d = create_dataset_function(p, args)
ds.append(d)
ds = ConcatDataset(ds, weights=dataset_weights)
以上,可看到path类似给定参数,然后使用os.walk遍历所有.jpg格式数据,并给成绝对路径,且将所有数据cat为一个列表。
4、make_dataset_full源码展示
def make_dataset_full(path, split, args, create_dataset_function,
dataset_weights=None, random_mapping=True, is_train_data=False, batch_from_same_dataset=False, **kwargs):
"""function to create datasets+tokenizers for common options"""
print_all('make dataset ' + str(path), level='DEBUG')
assert isinstance(path, list)
if args.iterable_dataset: # cannot indexed
# the random mapping is flexible and efficient, but sometimes we have pratical issue
# For instance, someone just gives you a iterable dataset, e.g. webdataset
from .webds import ConfiguredResampledShards, DataPipeline
valid_types = (ConfiguredResampledShards, DataPipeline)
assert split[0] == 1, 'Iterable dataset cannot auto split.'
ds = []
for p in path:
d = create_dataset_function(p, args)
assert isinstance(d, valid_types)
ds.append(d)
# ds = ChainDataset(ds) # please merge them in a url if chain
if batch_from_same_dataset:
assert args.num_workers <= 1, 'We cannot control the actual speed of different workers, may mix different iterable parts.'
ds = AlterDataset(ds, weights=dataset_weights, seed=args.seed, batch_from_same_dataset=batch_from_same_dataset, batch_size=args.batch_size)
return ds
if split is None:
split = [1.]
if not should_split(split):
ds = []
for p in path:
d = create_dataset_function(p, args)
ds.append(d)
ds = ConcatDataset(ds, weights=dataset_weights)
if random_mapping:
if args.epochs is not None: # not auto-scale, but use a given number of epoches.
ds = RandomDataset(ds, scale=args.epochs, seed=args.seed)
else:
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
if is_train_data:
# only train-dataset will set this to True,
# so we enlarge it to make sure that the data is sufficient.
scale = max(200, 1 + (args.train_iters * args.batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
else:
scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
ds = RandomMappingDataset(ds, scale=scale)
return ds [-1, 9, C3, [512]],
else:
# must first split datasets, then reweight/concat, finally random-mapping.
# this order avoids overlapping.
train_ds, valid_ds, test_ds = [], [], []
for p in path:
d = create_dataset_function(p, args)
if should_split(split):
dtrain, dvalid, dtest = split_ds(d, split, block_size=args.block_size, seed=args.seed)
train_ds.append(dtrain)
valid_ds.append(dvalid)
test_ds.append(dtest)
train_ds = ConcatDataset(train_ds, weights=dataset_weights)
valid_ds = ConcatDataset(valid_ds, weights=dataset_weights)
test_ds = ConcatDataset(test_ds, weights=dataset_weights)
if random_mapping:
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
scale = max(200, 1 + (args.train_iters * args.batch_size * world_size) // len(train_ds))
train_ds = RandomMappingDataset(train_ds, scale=scale)
scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(valid_ds))
valid_ds = RandomMappingDataset(valid_ds, scale=scale)
test_ds = RandomMappingDataset(test_ds)
return train_ds, valid_ds, test_ds
五、sat库之make_data_loader函数源码解读
在三说过类似torch的dataloader方法,在这里我大致说下,这里loader主要对dataset包装,使用顺序采用,配置类似的world_size与相应环境等内容,这也是库本身包装好的,可直接使用,其源码如下:
def make_data_loader(dataset, batch_size, args, split, collate_fn=None):
world_size = torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
distributed = world_size > 1
# if IterableDataset, assume everything is properly configured. (pre-sharded)
if isinstance(dataset, IterableDataset):
if split in ['val', 'test'] and args.strict_eval:
raise ValueError('IterableDataset cannot be used for validation or testing if `args.strict_eval=True`, because we cannot infer the length of the final batch before reading out them.')
args.val_last_shape = [1] * world_size # just fake it, not actually used
args.val_drop_number = 0
args.test_last_shape = [1] * world_size
args.test_drop_number = 0
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size//world_size,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=collate_fn,
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
)
sampler = torch.utils.data.SequentialSampler(dataset) # 顺序采样
drop_last = False # COMMENT: this is already solved by the complex logic of last_shape and drop_number.
# the GPUs in the same model parallel group receive the same data
if distributed: # TODO reformat this, but it is not urgent
gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
batch_sampler = DistributedBatchSampler(sampler,
batch_size,
drop_last,
rank,
world_size,
gradient_accumulation_steps=gradient_accumulation_steps)
else:
batch_sampler = torch.utils.data.BatchSampler(sampler,
batch_size,
drop_last)
last_len = len(dataset) % batch_size
batch_per_worker = batch_size // world_size
last_shape = [batch_per_worker] * (last_len//batch_per_worker) # some processes get full batch
if last_len != 0:
if last_len % batch_per_worker != 0:
last_shape.append(last_len % batch_per_worker) # one process get the rest (<1 batch)
drop_number = world_size - ((last_len-1)//batch_per_worker + 1)
# other processes get nothing, but append 1 for running. will drop later according to drop_number.
for j in range(drop_number):
last_shape.append(1)
else:
drop_number = 0
if split=='val':
args.val_last_shape = last_shape
args.val_drop_number = drop_number
elif split=='test':
args.test_last_shape = last_shape
args.test_drop_number = drop_number
data_loader = torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=collate_fn,
prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
)
return data_loader
六、ItemDataset数据类源码解读
在这里,终于进入最核心部分,数据加工
class ItemDataset(Dataset):
def __init__(self, image_processor, text_processor, args, data_dirs, cross_image_processor=None, **kwargs):
super().__init__()
self.data = self.load_data(data_dirs) # 获得.jpg图片绝对路径的列表,保存到self.data变量中
self.image_processor, self.text_processor, self.cross_image_processor = image_processor, text_processor, cross_image_processor # 传递的数据加工函数
def process_img(self, img):
img_dict = {'vision': self.image_processor(img)}
if self.cross_image_processor:
img_dict.update({'cross': self.cross_image_processor(img)})
return img_dict
def process_text(self, answer, prompt):
return self.text_processor(answer, prompt)
def load_data(self, data_dir):
all_files = find_all_files(data_dir, suffix=".jpg")
print_rank0(f"find {len(all_files)} samples in all...")
return all_files
def __len__(self):
return len(self.data)
def __getitem__(self, index):
data = self.data[index] # 获得图片的绝对路径
# img
try:
img = Image.open(data).convert('RGB') # 载入图片
except Exception as e:
print_rank0(e, level=logging.WARNING)
return {}
img_dict = self.process_img(img) # 图像加工
# text
label = data.split('/')[-1].split('.')[0]
uni_key = label
text_dict = self.process_text(label, "CAPTCHA:")
if text_dict is None:
print_rank0(f"Process text failed. Please check the max_target_length & max_source_length.\n The data is {data}", level=logging.WARNING)
return {}
# other attr
ret = {**img_dict, **text_dict, "question_id": uni_key}
return ret
1、图像数据处理
1、process_img(self, img)
图像处理最终为一个img_dict的字典,包含2部分处理,其源码如下:
def process_img(self, img):
img_dict = {'vision': self.image_processor(img)}
if self.cross_image_processor:
img_dict.update({'cross': self.cross_image_processor(img)})
return img_dict
2、self.image_processor(img)
image_processor函数传递为blip2_image_processor_func_with_inputs函数
def blip2_image_processor_func_with_inputs(image_processor, image):
return {'image': image_processor(image).unsqueeze(0), 'input_ids': torch.zeros(1, 1, dtype=torch.long), 'position_ids': None, 'attention_mask': torch.ones(1, 1, dtype=torch.long)}
2、文本数据处理
1、原始图像验证码文本数据处理
原始验证码数据是一个图片,其名字为验证码数字命名,在源码data表示图像绝对路径(/home/*/0a4Ovs8789.jpg)。因此,使用以下方式label = data.split(‘/’)[-1].split(‘.’)[0]即可获得文本(0a4Ovs8789),随后使用self.process_text函数即可实现验证码文本数据。
label = data.split('/')[-1].split('.')[0]
uni_key = label
text_dict = self.process_text(label, "CAPTCHA:")
验证码数据图如下:
2、使用自己文本数据代码修改
假如一个图片对应一个数据json文件,其内容如下:
{
"conversations": [
{
"role": "assistant",
"content": "虽然无法从照片中确定他们的确切目的地或目的地,但这两名滑板运动员很可能正在使用公共交通工具前往滑板公园、休闲场所或其他可以练习滑板的地方。他们也可以在空闲时间携带滑板进行娱乐活动,往返于学校、工作或其他日常活动。或者,他们可能只是带着滑板在不同地点之间旅行,作为首选的交通方式。"
}
]
}
代码内容是获取json文件相应内容:
json_path = data.replace('images','labels_zh')[:-4]+'.json'
label=self.read_json(json_path)
label = label['captions'][0]['content'] # 获取描述内容
uni_key = label
text_dict = self.process_text(label, "CAPTCHA:")
读取json辅助代码如下:
def read_json(self,json_root):
import json
with open(json_root, encoding='utf-8') as f:
json_info = json.load(f)
return json_info
最后文本将使用以下方式加工文本,我后期有时间在具体解读。
text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length) # 获得文本加工函数