CogVLM训练源码解读--数据处理

文章目录

  • 前言
  • 一、数据主函数源码解读
    • 1、图像函数源码调用解读
    • 2、文本函数源码调用解读
    • 3、tokenizer生成函数
    • 4、llama2_text_processor文本处理函数解读
  • 二、create_dataset_function函数源码代码解读
  • 三、sat库之make_loaders函数源码解读
    • 1、make_loaders函数调用说明
    • 2、make_loaders函数源码解读
      • 第一部分源码解读
      • 第二部分源码解读
      • 第三部分源码解读
      • make_loaders源码展示
  • 四、sat库之make_dataset_full函数源码解读
    • 1、参数配置
    • 2、数据格式说明
    • 3、make_dataset_full数据读取源码解读
    • 4、make_dataset_full源码展示
  • 五、sat库之make_data_loader函数源码解读
  • 六、ItemDataset数据类源码解读
    • 1、图像数据处理
      • 1、process_img(self, img)
      • 2、self.image_processor(img)
    • 2、文本数据处理
      • 1、原始图像验证码文本数据处理
      • 2、使用自己文本数据代码修改


前言

本文是CogVLM是一个多模态大型模型,它能够处理文本、图像和其他类型的数据。在数据处理方面,CogVLM可以接收多种类型的输入数据,包括文本、图像、音频等。然而,很少有人对代码数据处理进行解读或者基本找不到。基于此,本文将结合源码给出CogVLM大模型数据处理内容,主要包含图像数据处理、文本tokenizer构建、文本加工与修改自己文本方法代码修改。总之,我将结合代码一步一步带领读者实现大模型数据处理源码内容。


一、数据主函数源码解读

CogVLM的数据处理包含2部分,一部分是图像数据处理,另一部分是文本数据处理。其源码位于finetune_cogvlm_demo.py文件,如下:

from utils.utils import llama2_tokenizer
    tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
    image_processor = get_image_processor(args.eva_args["image_size"][0])  # 获得图像加工函数,并附image_size参数
    text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)  # 获得文本加工函数

    model = training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor), collate_fn=data_collator, forward_step_eval=forward_step_eval)
    #        训练函数      参数         模型                    训练机制                                    处理数据函数-->给该函数添加部分参数<---参数为函数                               dataloader整合数据                 评估函数               

从上面可看出获得图像处理函数调用是通过get_image_processor,而获得文本处理函数是调用llama2_text_processor函数,然后通过training_main函数把图像处理与文本处理函数分别作为image_processor与text_processor参数传递。我将说明这2个函数如何处理成参数方法。

1、图像函数源码调用解读

图像处理实际是调用blip2_image_processor_func_with_inputs函数(后面我会详细解读),实现图像加工,最终作为image_processor函数参数。源码如下:

def get_image_processor(image_size):
    return partial(blip2_image_processor_func_with_inputs, BlipImageEvalProcessor(image_size))

2、文本函数源码调用解读

文本处理实际是通过huggingface的方式调用tokenizer等方法,在通过llama2_text_processor类对文本处理,我后面会解读,这里介绍通过这样方式作为text_processor函数参数。源码如下:

    from utils.utils import llama2_tokenizer
    tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
    text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)  # 获得文本加工函数

3、tokenizer生成函数

对文本的tokenizer处理,CogVLM调用huggingface的llama函数,使用from utils.utils import llama2_tokenizer调用,在源码中使用函数如下:

tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version) 

而对于语言函数

from transformers import LlamaTokenizer
def llama2_tokenizer(tokenizer_path, signal_type="base"):
    tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 32000
    tokenizer.boi = "[IMG]"
    tokenizer.eoi = "[/IMG]"
    assert signal_type in ["base", "chat", "vqa", "chat_old"]
    tokenizer.signal_type = signal_type
    return tokenizer

4、llama2_text_processor文本处理函数解读

该函数是文本处理相关内容,我将其代码罗列如下:

class llama2_text_processor:
    def __init__(self, tokenizer, max_target_length=2048, image_length=257, model=None):
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length
        self.image_length = image_length

    def __call__(self, caption, prompt=""):
        if '<EOI>' not in prompt:
            prompt = self.replace_tags_with_empty(prompt)
            # caption = self.replace_tags_with_empty(caption)
            history = []
            prompt = self.history_to_prompt(prompt, history)

        input_ids = [self.tokenizer.bos_token_id]

        prompt_splits = prompt.split('<EOI>')
        caption_splits = caption.split('<EOI>')
        if len(prompt_splits) > 0:
            input_ids.extend(self.tokenizer.encode(prompt_splits[0], add_special_tokens=False))
        for tokens in prompt_splits[1:]:
            tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False)
            input_ids.extend(tokens_with_img)
        context_length = len(input_ids) + (len(prompt_splits)-1) * (self.image_length + 1)
        if context_length > self.max_target_length - 10:
            return None
        if len(caption_splits) > 0:
            input_ids.extend(self.tokenizer.encode(caption_splits[0], add_special_tokens=False))
        for tokens in caption_splits[1:]:
            tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False)
            input_ids.extend(tokens_with_img)

        if len(input_ids) > self.max_target_length - self.image_length - 5:
            input_ids = input_ids[:self.max_target_length - self.image_length - 5]

        input_ids += [self.tokenizer.eos_token_id]

        while -100 in input_ids:
            img_idx = input_ids.index(-100)
            input_ids = input_ids[:img_idx] + [0] * (self.image_length + 1) + [-1] + input_ids[img_idx+1:]

        image_position = []
        while -1 in input_ids:
            img_idx = input_ids.index(-1)
            input_ids[img_idx] = 0
            image_position.append(img_idx)

        image_embed_mask = [0] * len(input_ids)
        vision_expert_mask = [0] * len(input_ids)
        image_rope_mask = [0] * len(input_ids)
        for idx in image_position:
            image_embed_mask[idx-self.image_length-1: idx+1] = [1] * (self.image_length + 2)
            vision_expert_mask[idx-self.image_length-1: idx] = [1] * (self.image_length + 1)
            image_rope_mask[idx - self.image_length: idx] = [1] * self.image_length
        attention_mask = [1] * len(input_ids)
        labels = [-100] * context_length + input_ids[context_length:]

        pad_len = self.max_target_length - len(input_ids)
        input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
        attention_mask = attention_mask + [1] * pad_len
        vision_expert_mask = vision_expert_mask + [0] * pad_len
        image_embed_mask = image_embed_mask + [0] * pad_len
        image_rope_mask = image_rope_mask + [0] * pad_len
        np_mask = np.tril(np.expand_dims(np.array(attention_mask), 0).repeat(len(attention_mask), 0))
        labels = labels + [-100] * pad_len

        for idx in image_position:
            labels[idx-self.image_length-1: idx+1] = [-100] * (self.image_length + 2)

        position_ids = []
        pid = -1
        for i in range(len(input_ids)):
            if image_rope_mask[i] == 0 or (i > 0 and image_rope_mask[i] != image_rope_mask[i - 1]):
                pid += 1
            position_ids.append(pid)

        input_ids = torch.tensor(input_ids).unsqueeze(0)
        labels = torch.tensor(labels).unsqueeze(0)
        attention_mask = torch.from_numpy(np_mask).unsqueeze(0).unsqueeze(0)
        image_embed_mask = torch.tensor(image_embed_mask).unsqueeze(0)
        vision_expert_mask = torch.tensor(vision_expert_mask).unsqueeze(0)
        image_rope_mask = torch.tensor(image_rope_mask).unsqueeze(0)
        position_ids = torch.tensor(position_ids).unsqueeze(0)
        context_length = torch.tensor(context_length).unsqueeze(0).long()
        return {'input_ids': input_ids, 'labels': labels, 'position_ids': position_ids, 'attention_mask': attention_mask, 'image_embed_mask': image_embed_mask,
                'context_length': context_length, 'image_position': image_position, 'vision_expert_mask': vision_expert_mask, 'image_rope_mask': image_rope_mask
                }

    def history_to_prompt(self, query, history):
        return _history_to_prompt[self.tokenizer.signal_type](self, query, history)

    def replace_tags_with_empty(self, text):
        return re.sub('<pad>|<s>|</s>|<EOI>', '', text)

二、create_dataset_function函数源码代码解读

在使用model = training_main函数create_dataset_function=partial(create_dataset_function, image_processor, text_processor),借助python自带partial函数实现数据类处理,其源码如下:


from utils.utils import ItemDataset
def create_dataset_function(image_processor, text_processor, path, args):
    dataset = ItemDataset(image_processor, text_processor, args, path)
    return dataset

我们可发现该函数调用了ItemDataset类,该类恰好是数据加工的迭代器类,类似taorch的dataset等功能,实际也是继承torch的dataset类,进一步封装加工数据。其中image_processor, text_processor是上面提到图像加工与文本加工方法代码。

三、sat库之make_loaders函数源码解读

在上面,我们已给出training_main函数,该函数集成数据处理、模型训练等方法,其中hooks保存不同方法,而数据处理代码为make_loaders。

1、make_loaders函数调用说明

make_loaders为模型数据加载方法,又是sat库集成方法,其源码如下:

# Data stuff. 数据处理方法适用
    train_data, val_data, test_data = make_loaders(args, hooks['create_dataset_function'], collate_fn=collate_fn) # 通过该函数调用

很明显,make_loaders参数为需要参数、处理数据类(集成dataset)、collate_fn函数(类似dataloader处理batch数据方式)。

源码位置:sat.data_utils.configure_data.py文件

2、make_loaders函数源码解读

我将maker_loaders源码分为三个部分,第一个部分是使用partial函数将传递数据类create_dataset_function(实际是ItemDataset类)赋参数并重命名函数为make_dataset;第二部分是否传训练、验证、测试路径,使用ItemDataset类处理数据,类似torch的dataset结构;第三部分也是调用sat库的make_data_loader包装对应数据,类似torch的dataloader结构,且调用自己传递的collate_fn方法。

第一部分源码解读

将我们对数据处理create_dataset_function类(实际是ItemDataset类)使用sat库的make_dataset_full函数包装,源码如下:

    make_dataset = partial(make_dataset_full, create_dataset_function=create_dataset_function, batch_from_same_dataset=args.batch_from_same_dataset)

第二部分源码解读

根据传递路径参数,使用上面make_dataset方法处理数据,类似torch的dataset,源码如下:

# make datasets splits and tokenizer
    train = None
    valid = None
    test = None

    if args.train_data is not None:
        train = make_dataset(**data_set_args, args=args, dataset_weights=args.train_data_weights, is_train_data=True)
        if should_split(split):
            train, valid, test = train

    # make training and val dataset if necessary
    if valid is None and args.valid_data is not None:
        eval_set_args['path'] = args.valid_data
        valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
    if test is None and args.test_data is not None:
        eval_set_args['path'] = args.test_data
        test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)

第三部分源码解读

将我们处理的datset数据进行dataloader包装,类似torch的dataloader,且调用自己写的collate_fn方法,源码如下:

# wrap datasets with data loader
    if train is not None and args.batch_size > 0:
        train = make_data_loader(train, batch_size, args, split='train', collate_fn=collate_fn)
        args.do_train = True
    else:
        args.do_train = False
    eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
    if valid is not None:
        valid = make_data_loader(valid, eval_batch_size, args, split='val', collate_fn=collate_fn)
        args.do_valid = True
    else:
        args.do_valid = False
    if test is not None:
        test = make_data_loader(test, eval_batch_size, args, split='test', collate_fn=collate_fn)
        args.do_test = True
    else:
        args.do_test = False

make_loaders源码展示

另外,该函数也对相应参数做处理,如world_size等内容,所有源码如下:

def make_loaders(args, create_dataset_function, collate_fn=None):
    """makes training/val/test
    Args:
        args.train_data, args.valid_data, args.test_data: str. Paths to the dataset.
        args.split: str. format: "8,1,1". how to split train_data.
        args.dataset_type: use to create the right datasets. 
    """
    make_dataset = partial(make_dataset_full, create_dataset_function=create_dataset_function, batch_from_same_dataset=args.batch_from_same_dataset)

    world_size = torch.distributed.get_world_size(   group=mpu.get_data_parallel_group())
    batch_size = args.batch_size * world_size
    eval_batch_size = batch_size
    if args.eval_batch_size is not None:
        eval_batch_size = args.eval_batch_size * world_size
    
    split = get_split(args)

    data_set_args = {
        'path': args.train_data,
        'split': split,
    }

    eval_set_args = copy.copy(data_set_args)
    eval_set_args['split'] = [1.]
    
    # make datasets splits and tokenizer
    train = None
    valid = None
    test = None

    if args.train_data is not None:
        train = make_dataset(**data_set_args, args=args, dataset_weights=args.train_data_weights, is_train_data=True)
        if should_split(split):
            train, valid, test = train

    # make training and val dataset if necessary
    if valid is None and args.valid_data is not None:
        eval_set_args['path'] = args.valid_data
        valid = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)
    if test is None and args.test_data is not None:
        eval_set_args['path'] = args.test_data
        test = make_dataset(**eval_set_args, args=args, random_mapping=not args.strict_eval)

    # wrap datasets with data loader
    if train is not None and args.batch_size > 0:
        train = make_data_loader(train, batch_size, args, split='train', collate_fn=collate_fn)
        args.do_train = True
    else:
        args.do_train = False
    eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size
    if valid is not None:
        valid = make_data_loader(valid, eval_batch_size, args, split='val', collate_fn=collate_fn)
        args.do_valid = True
    else:
        args.do_valid = False
    if test is not None:
        test = make_data_loader(test, eval_batch_size, args, split='test', collate_fn=collate_fn)
        args.do_test = True
    else:
        args.do_test = False

    return train, valid, test

四、sat库之make_dataset_full函数源码解读

1、参数配置

我是在vscode运行,我对训练数据参数配置作为列子,可配置三种方式,假设CogVLM-SFT-311K文件夹有2个文件,分别为llava_instruction_multi_conversations_formate与llava_instruction_single_conversation_formate文件,可单独给子文件路径、也可给子文件上一个文件路径、也可给多个文件列表路径,具体如下:

# 第一种方式:
"--train-data", "/extend_disk/tj/data/CogVLM-SFT-311K/llava_instruction_multi_conversations_formate",
# 第二种方式:
"--train-data", "/extend_disk/tj/data/CogVLM-SFT-311K",
# 第三种方式:
"--train-data", "/extend_disk/tj/data/CogVLM-SFT-311K/llava_instruction_multi_conversations_formate", "/extend_disk/tj/data/CogVLM-SFT-311K/llava_instruction_single_conversation_formate" ,           

2、数据格式说明

单个数据文件夹内容如下:
在这里插入图片描述

json文件内容如下:
在这里插入图片描述

3、make_dataset_full数据读取源码解读

make_loaders函数中的make_dataset类调用是被make_dataset_full函数包装,主要处理一些逻辑,使其进入数据类为统一格式,其源码如下:

ds = []
for p in path:
    d = create_dataset_function(p, args)
    ds.append(d)
ds = ConcatDataset(ds, weights=dataset_weights)

以上,可看到path类似给定参数,然后使用os.walk遍历所有.jpg格式数据,并给成绝对路径,且将所有数据cat为一个列表。

4、make_dataset_full源码展示

def make_dataset_full(path, split, args, create_dataset_function, 
        dataset_weights=None, random_mapping=True, is_train_data=False, batch_from_same_dataset=False, **kwargs):
    """function to create datasets+tokenizers for common options"""
    print_all('make dataset ' + str(path), level='DEBUG')
    assert isinstance(path, list)

    if args.iterable_dataset: # cannot indexed
        # the random mapping is flexible and efficient, but sometimes we have pratical issue
        # For instance, someone just gives you a iterable dataset, e.g. webdataset
        from .webds import ConfiguredResampledShards, DataPipeline
        valid_types = (ConfiguredResampledShards, DataPipeline)
        
        assert split[0] == 1, 'Iterable dataset cannot auto split.'
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            assert isinstance(d, valid_types)
            ds.append(d)
        # ds = ChainDataset(ds) # please merge them in a url if chain
        if batch_from_same_dataset:
            assert args.num_workers <= 1, 'We cannot control the actual speed of different workers, may mix different iterable parts.'
        ds = AlterDataset(ds, weights=dataset_weights, seed=args.seed, batch_from_same_dataset=batch_from_same_dataset, batch_size=args.batch_size)
        return ds

    if split is None:
        split = [1.] 
    if not should_split(split):
        ds = []
        for p in path:
            d = create_dataset_function(p, args)
            ds.append(d)
        ds = ConcatDataset(ds, weights=dataset_weights)
        if random_mapping:
            if args.epochs is not None: # not auto-scale, but use a given number of epoches.
                ds = RandomDataset(ds, scale=args.epochs, seed=args.seed)
            else:
                world_size = torch.distributed.get_world_size(
                    group=mpu.get_data_parallel_group())
                if is_train_data:
                # only train-dataset will set this to True,
                # so we enlarge it to make sure that the data is sufficient.
                    scale = max(200, 1 + (args.train_iters * args.batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
                else:
                    scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(ds))
                ds = RandomMappingDataset(ds, scale=scale)
        return ds   [-1, 9, C3, [512]],
    else:
        # must first split datasets, then reweight/concat, finally random-mapping.
        # this order avoids overlapping.
        train_ds, valid_ds, test_ds = [], [], []
        for p in path:
            d = create_dataset_function(p, args)
            if should_split(split):
                dtrain, dvalid, dtest = split_ds(d, split, block_size=args.block_size, seed=args.seed)
                train_ds.append(dtrain)
                valid_ds.append(dvalid)
                test_ds.append(dtest)
        train_ds = ConcatDataset(train_ds, weights=dataset_weights)
        valid_ds = ConcatDataset(valid_ds, weights=dataset_weights)
        test_ds = ConcatDataset(test_ds, weights=dataset_weights)
        if random_mapping:
            world_size = torch.distributed.get_world_size(
                group=mpu.get_data_parallel_group())
            scale = max(200, 1 + (args.train_iters * args.batch_size * world_size) // len(train_ds))
            train_ds = RandomMappingDataset(train_ds, scale=scale)
            scale = max(200, 1 + ((1 + args.train_iters // args.eval_interval) * args.eval_iters * args.eval_batch_size * args.gradient_accumulation_steps * world_size) // len(valid_ds))
            valid_ds = RandomMappingDataset(valid_ds, scale=scale)
            test_ds = RandomMappingDataset(test_ds)
        return train_ds, valid_ds, test_ds

五、sat库之make_data_loader函数源码解读

在三说过类似torch的dataloader方法,在这里我大致说下,这里loader主要对dataset包装,使用顺序采用,配置类似的world_size与相应环境等内容,这也是库本身包装好的,可直接使用,其源码如下:

def make_data_loader(dataset, batch_size, args, split, collate_fn=None):

    world_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group())
    distributed = world_size > 1

    # if IterableDataset, assume everything is properly configured. (pre-sharded) 
    if isinstance(dataset, IterableDataset):
        if split in ['val', 'test'] and args.strict_eval:
            raise ValueError('IterableDataset cannot be used for validation or testing if `args.strict_eval=True`, because we cannot infer the length of the final batch before reading out them.')
        args.val_last_shape = [1] * world_size # just fake it, not actually used
        args.val_drop_number = 0
        args.test_last_shape = [1] * world_size
        args.test_drop_number = 0
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size//world_size,
            num_workers=args.num_workers,
            pin_memory=True,
            collate_fn=collate_fn,
            prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
            )

    sampler = torch.utils.data.SequentialSampler(dataset)  # 顺序采样

    drop_last = False # COMMENT: this is already solved by the complex logic of last_shape and drop_number.

    # the GPUs in the same model parallel group receive the same data
    if distributed: # TODO reformat this, but it is not urgent
        gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 1)
        batch_sampler = DistributedBatchSampler(sampler,
                                                batch_size,
                                                drop_last,
                                                rank,
                                                world_size,
                                                gradient_accumulation_steps=gradient_accumulation_steps)
    else:
        batch_sampler = torch.utils.data.BatchSampler(sampler,
                                                      batch_size,
                                                      drop_last)
    last_len = len(dataset) % batch_size
    batch_per_worker = batch_size // world_size
    last_shape = [batch_per_worker] * (last_len//batch_per_worker) # some processes get full batch
    if last_len != 0:
        if last_len % batch_per_worker != 0:
            last_shape.append(last_len % batch_per_worker) # one process get the rest (<1 batch)
        drop_number = world_size - ((last_len-1)//batch_per_worker + 1)
        # other processes get nothing, but append 1 for running. will drop later according to drop_number.
        for j in range(drop_number): 
            last_shape.append(1)
    else:
        drop_number = 0
    if split=='val':
        args.val_last_shape = last_shape
        args.val_drop_number = drop_number
    elif split=='test':
        args.test_last_shape = last_shape
        args.test_drop_number = drop_number
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_sampler=batch_sampler,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              collate_fn=collate_fn,
                                              prefetch_factor=args.prefetch_factor if args.num_workers > 0 else None,
                                              )
    return data_loader


六、ItemDataset数据类源码解读

在这里,终于进入最核心部分,数据加工

class ItemDataset(Dataset):
    def __init__(self, image_processor, text_processor, args, data_dirs, cross_image_processor=None, **kwargs):
        super().__init__()
        self.data = self.load_data(data_dirs)  # 获得.jpg图片绝对路径的列表,保存到self.data变量中
        self.image_processor, self.text_processor, self.cross_image_processor = image_processor, text_processor, cross_image_processor # 传递的数据加工函数

    def process_img(self, img):
        img_dict = {'vision': self.image_processor(img)}
        if self.cross_image_processor:
            img_dict.update({'cross': self.cross_image_processor(img)})
        return img_dict
    
    def process_text(self, answer, prompt):
        return self.text_processor(answer, prompt)
    
    def load_data(self, data_dir):
        all_files = find_all_files(data_dir, suffix=".jpg")
        print_rank0(f"find {len(all_files)} samples in all...")
        return all_files
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]  # 获得图片的绝对路径
        # img
        try:
            img = Image.open(data).convert('RGB')  # 载入图片
        except Exception as e:
            print_rank0(e, level=logging.WARNING)
            return {}
        img_dict = self.process_img(img)  # 图像加工
        # text
        label = data.split('/')[-1].split('.')[0]
        uni_key = label
        text_dict = self.process_text(label, "CAPTCHA:")
        if text_dict is None:
            print_rank0(f"Process text failed. Please check the max_target_length & max_source_length.\n The data is {data}", level=logging.WARNING)
            return {}
        # other attr
        ret = {**img_dict, **text_dict, "question_id": uni_key}
        return ret

1、图像数据处理

1、process_img(self, img)

图像处理最终为一个img_dict的字典,包含2部分处理,其源码如下:

    def process_img(self, img):
        img_dict = {'vision': self.image_processor(img)}
        if self.cross_image_processor:
            img_dict.update({'cross': self.cross_image_processor(img)})
        return img_dict

2、self.image_processor(img)

image_processor函数传递为blip2_image_processor_func_with_inputs函数

def blip2_image_processor_func_with_inputs(image_processor, image):
    return {'image': image_processor(image).unsqueeze(0), 'input_ids': torch.zeros(1, 1, dtype=torch.long), 'position_ids': None, 'attention_mask': torch.ones(1, 1, dtype=torch.long)}

2、文本数据处理

1、原始图像验证码文本数据处理

原始验证码数据是一个图片,其名字为验证码数字命名,在源码data表示图像绝对路径(/home/*/0a4Ovs8789.jpg)。因此,使用以下方式label = data.split(‘/’)[-1].split(‘.’)[0]即可获得文本(0a4Ovs8789),随后使用self.process_text函数即可实现验证码文本数据。

label = data.split('/')[-1].split('.')[0]  
uni_key = label
text_dict = self.process_text(label, "CAPTCHA:")

验证码数据图如下:
在这里插入图片描述

2、使用自己文本数据代码修改

假如一个图片对应一个数据json文件,其内容如下:

{
  "conversations": [
    {
      "role": "assistant",
      "content": "虽然无法从照片中确定他们的确切目的地或目的地,但这两名滑板运动员很可能正在使用公共交通工具前往滑板公园、休闲场所或其他可以练习滑板的地方。他们也可以在空闲时间携带滑板进行娱乐活动,往返于学校、工作或其他日常活动。或者,他们可能只是带着滑板在不同地点之间旅行,作为首选的交通方式。"
    }
  ]
}

代码内容是获取json文件相应内容:

        json_path = data.replace('images','labels_zh')[:-4]+'.json'
        label=self.read_json(json_path)
        label = label['captions'][0]['content']  # 获取描述内容
        uni_key = label
        text_dict = self.process_text(label, "CAPTCHA:")

读取json辅助代码如下:

    def read_json(self,json_root):
        import json
        with open(json_root, encoding='utf-8') as f:
            json_info = json.load(f)
        return json_info

最后文本将使用以下方式加工文本,我后期有时间在具体解读。

text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)  # 获得文本加工函数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/392737.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CSS 多色阴影效果和旋转动画的加载指示器

<template><!-- 创建一个装载加载动画的容器 --><view class="loader"><!-- 内部阴影层,用于放置动态文本 --><view class="intern"></view><!-- 外部阴影层,包含旋转和颜色变化的圆形阴影 --><view class…

RK3399平台开发系列讲解(USB篇)USB 枚举和断开过程

🚀返回专栏总目录 文章目录 一、连接与检测二、USB设备枚举三、断开过程沉淀、分享、成长,让自己和他人都能有所收获!😄 📢介绍 USB 枚举/断开过程。 一、连接与检测 二、USB设备枚举 USB设备枚举一

第六节笔记:OpenCompass 大模型评测

视频链接&#xff1a;https://www.bilibili.com/video/BV1Gg4y1U7uc/?spm_id_from333.788&vd_source3bbd0d74033e31cbca9ee35e111ed3d1

手写myscrapy(二)

我们看一下scrapy的系统架构设计方法和思路&#xff1a; 模块化设计&#xff1a; Scrapy采用模块化设计&#xff0c;将整个系统划分为多个独立的模块&#xff0c;包括引擎&#xff08;Engine&#xff09;、调度器&#xff08;Scheduler&#xff09;、下载器&#xff08;Downl…

RIP协议详解

​RIP是最早的动态路由协议&#xff0c;虽然已经过时并且很少使用&#xff0c;但是可以通过学习RIP并且和ospf等现在正在使用的路由协议对比&#xff0c;了解其工作原理和过时原因&#xff0c;具有很强的学习性。 一、RIP协议简介 RIP&#xff08;Routing Information Protoc…

Vue22 Vue监测数据改变的原理_数组

实例 <!DOCTYPE html> <html><head><meta charset"UTF-8" /><title>Vue监测数据改变的原理_数组</title><!-- 引入Vue --><script type"text/javascript" src"../js/vue.js"></script>&…

如何避免发送HTTP请求

资料来源 : 小林coding 小林官方网站 : 小林coding (xiaolincoding.com) 如何避免发送HTTP请求? 这个思路你看到是不是觉得很奇怪&#xff0c;不发送 HTTP 请求&#xff0c;那客户端还怎么和服务器交互数据?小林你这不是要流氓嘛? 冷静冷静&#xff0c;你说的没错&#xf…

jmeter-12jmeter的录制功能

文章目录 什么情况下使用录制功能?操作流程具体设置如下观察结果什么情况下使用录制功能? 在测试过程中,很多时候可能会没有接口文档,这样你不知道请求方式,url,等等如何进行测试? jmeter提供了对应的录制功能。录制功能可以抓到具体的接口信息 操作流程 创建线程组 …

Pandas 数据处理:从基础到高级的完整指南【第84篇—Pandas 数据处理】

Pandas 数据处理&#xff1a;从基础到高级的完整指南 Pandas 是一个强大的数据分析工具&#xff0c;广泛应用于数据科学、机器学习和统计分析等领域。本文将介绍 Pandas 模块的基础知识&#xff0c;包括数据结构、数据导入、数据选择与过滤等方面&#xff0c;通过实际代码示例…

R语言课程论文-飞机失事数据可视化分析

数据来源&#xff1a;Airplane Crashes Since 1908 (kaggle.com) 代码参考&#xff1a;Exploring historic Air Plane crash data | Kaggle 数据指标及其含义 指标名 含义 Date 事故发生日期(年-月-日) Time 当地时间&#xff0c;24小时制&#xff0c;格式为hh:mm Locat…

Android 11以上获取不到第三方app是否安装

开年第一篇&#xff0c;处理了一下年前的小问题。 问题&#xff1a;本地app跳转到第三方app地图进行导航&#xff0c;获取不到第三方地图是否安装。 解决&#xff1a; 1.添加包名 This can be done by adding a <queries> element in the Android manifest.在app下的…

红队打靶练习:IMF: 1

目录 信息收集 1、arp 2、nmap 3、nikto 目录探测 gobuster dirsearch WEB 信息收集 get flag1 get flag2 get flag3 SQL注入 漏洞探测 脱库 get flag4 文件上传 反弹shell 提权 get flag5 get flag6 信息收集 1、arp ┌──(root㉿ru)-[~/kali] └─# a…

【力扣hot100】刷题笔记Day5

前言 回学校了&#xff0c;荒废了半天之后打算奋发图强猛猛刷题&#xff0c;找实习&#xff01;赚钱&#xff01;&#xff01; 560. 和为 K 的子数组 - 力扣&#xff08;LeetCode&#xff09; 前缀法 哈希表 这个题解解释比官方清晰&#xff0c;截个图方便看&#xff0c;另一…

慎投!2023年共124本SCI/SSCI被剔除汇总(附电子档下载目录)

2023年SCI/SSCI剔除期刊汇总 2023年3月20日&#xff0c;Web of Science核心期刊目录再次更新&#xff01;共有50本期刊被剔除出SCIE & SSCI期刊目录&#xff0c;其中大部分为Hindawi旗下期刊&#xff08;19本&#xff09;&#xff0c;引起不小的轰动&#xff01; 2023年全…

数据结构之时空复杂度

一、前言 1&#xff09;什么是数据结构 数据结构(Data Structure)是计算机存储、组织数据的方式&#xff0c;指相互之间存在一种或多种特定关系的数据元素的 集合。 2&#xff09;什么是算法 算法(Algorithm):就是定义良好的计算过程&#xff0c;他取一个或一组的值为输入&am…

计算机网络——14CDN

CDN 视频流化服务和CDN&#xff1a;上下文 视频流量&#xff1a;占据着互连网大部分的带宽 Netflix&#xff0c;YouTube&#xff1a;占据37%&#xff0c;16%的下行流量 挑战&#xff1a;规模性-如何服务~1B用户&#xff1f; 单个超级服务器无法提供服务&#xff08;为什么&am…

备战蓝桥杯---数学之博弈论基础1

目录 1.对称博弈 2.巴什博弈&#xff1a; 3.NIM博弈&#xff1a; 注意一个法则&#xff1a; 1.对称博弈 我们先看一个经典的例子&#xff1a; 下面是分析&#xff1a; 2.巴什博弈&#xff1a; 我们只要先手取1个&#xff0c;然后先手再去取5-刚刚后手的数字即可。 当石子数…

SHERlocked93 的 2021 年终总结

我还是和往年一样&#xff0c;总结发的又晚了一点&#xff0c;为什么又发这么晚呢&#xff0c;因为懒 年终总结 疫情之后时间时间过的太快了&#xff0c;不知道是不是只有我这样感觉。 四五月份去兰州玩了下&#xff08;其实是出差&#xff09;&#xff0c;终于看到了黄土高原&…

机器视觉与嵌入式技术:开拓自动驾驶和远程监控新视野

&#xff08;本文为简单介绍&#xff0c;观点源于网络&#xff09; 机器视觉系统是指利用计算机来模拟人眼的识别与判断。在自动驾驶和远程监控领域&#xff0c;机器视觉结合嵌入式技术的应用&#xff0c;不仅极大地提升了自动化水平&#xff0c;而且开辟了新的技术视野。 在…

迅为3A5000_7A2000开发板龙芯自主指令系统支持PCIE3.0、USB3.0、SATA3.0、HDMI、VGA等

性能强 采用全国产龙芯3A5000处理器&#xff0c;基于龙芯自主指令系统 (LoongArch)的LA464微结构&#xff0c;并进一步提升频率&#xff0c;降低功耗&#xff0c;优化性能。 桥片 采用龙芯 7A2000&#xff0c;支持PCIE 3.0、USB 3.0和 SATA 3.0.显示接口2 路、HDMI 和1路 VGA…