Introduction to P-tuning
P-tuning adapts a pretrained language model to a downstream task without updating the model's own weights: a small set of continuous prompt embeddings is learned and prepended to the input token embeddings, and only these prompt vectors are optimized during training. The example below applies this idea to binary sentence classification with bert-base-chinese.
Code Implementation
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
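# Load the tokenizer and a BERT model with a binary sequence-classification head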
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)
def train(tokenizer, model, prompt_length, prompt, data):
    # Freeze all BERT parameters; only the prompt embeddings are trained
    for param in model.bert.parameters():
        param.requires_grad = False
    # The optimizer updates only the prompt vectors
    optimizer = torch.optim.Adam([prompt], lr=1e-3)
    # Training loop
    num_epochs = 8
    losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for text, label in data:
            # Tokenize the input and build the label tensor of shape [batch_size]
            inputs = tokenizer(text, return_tensors='pt')
            labels = torch.tensor([label])
            # Access BERT's embedding layer
            bert_model = model.bert
            input_ids = inputs['input_ids']
            # Look up the token embeddings (no gradients needed for the frozen embeddings)
            with torch.no_grad():
                input_embeddings = bert_model.embeddings(input_ids)
            # Expand the prompt to the batch size and prepend it to the token embeddings.
            # unsqueeze(0) adds a batch dimension; expand(input_ids.size(0), -1, -1) repeats it
            # across the batch, with -1 leaving the remaining dimensions unchanged.
            prompt_embeddings = prompt.unsqueeze(0).expand(input_ids.size(0), -1, -1)
            prompted_input = torch.cat((prompt_embeddings, input_embeddings), dim=1)
            # Forward pass: extend the attention mask to cover the prompt positions
            attention_mask = torch.cat((torch.ones(prompt_embeddings.size()[:2], dtype=torch.long), inputs['attention_mask']), dim=1)
            outputs = bert_model(inputs_embeds=prompted_input, attention_mask=attention_mask)
            sequence_output = outputs.last_hidden_state
            # Classification head: skip the prompt positions, then keep only the first
            # remaining token (the [CLS] token) so logits has shape [batch_size, num_labels]
            logits = model.classifier(sequence_output[:, prompt_length:, :])
            logits = logits[:, 0, :]
            # Compute the loss, making sure logits and labels have matching shapes
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        losses.append(total_loss)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss/len(data)}')
    torch.save(prompt, 'path_to_trained_prompt.pt')
    return losses
def plot_loss(losses):
    plt.figure()
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Training loss')
    plt.show()
def predict_classify(tokenizer, model, prompt_length, trained_prompt, data):
    predict_list = []
    for input_text in data:
        inputs = tokenizer(input_text, return_tensors='pt')
        # Access BERT's embedding layer
        bert_model = model.bert
        input_ids = inputs['input_ids']
        # Look up the token embeddings
        with torch.no_grad():
            input_embeddings = bert_model.embeddings(input_ids)
        # Expand the trained prompt to the batch size and prepend it to the token embeddings
        prompt_embeddings = trained_prompt.unsqueeze(0).expand(input_ids.size(0), -1, -1)
        prompted_input = torch.cat((prompt_embeddings, input_embeddings), dim=1)
        # Build the new attention mask covering the prompt positions
        attention_mask = torch.cat((torch.ones(prompt_embeddings.size()[:2], dtype=torch.long), inputs['attention_mask']), dim=1)
        # Forward pass for inference
        with torch.no_grad():
            outputs = bert_model(inputs_embeds=prompted_input, attention_mask=attention_mask)
            sequence_output = outputs.last_hidden_state
            # Classification head: skip the prompt positions and keep only the first
            # remaining token (the [CLS] token)
            logits = model.classifier(sequence_output[:, prompt_length:, :])
            logits = logits[:, 0, :]
        # Take the arg-max class as the prediction
        predicted_label = torch.argmax(logits, dim=-1).item()
        print(f"Input data: {input_text}, Predicted label: {predicted_label}")
        predict_list.append(predicted_label)
    return predict_list
# P-tuning training
# Define the learnable prompt vectors
prompt_length = 5
prompt = torch.nn.Parameter(torch.randn(prompt_length, model.config.hidden_size))
# Training set
data = [("This movie is great", 1), ("This movie is bad", 0)]
# Train
losses = train(tokenizer, model, prompt_length, prompt, data)
# Plot the loss curve
plot_loss(losses)
# P-tuning prediction
prompt_length = 5
trained_prompt = torch.load('path_to_trained_prompt.pt')  # Load the trained prompt embeddings
input_text = ["This movie is good", "This movie is bad", "This movie is not good"]
predict_list = predict_classify(tokenizer, model, prompt_length, trained_prompt, input_text)
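The loop above spells out the mechanics by hand. The same technique is also packaged in Hugging Face's peft library, where P-tuning is configured through PromptEncoderConfig (the prompt embeddings are additionally reparameterized by a small prompt encoder). A minimal sketch, assuming peft and transformers are installed and using illustrative, untuned hyperparameters:

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PromptEncoderConfig, TaskType, get_peft_model

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
base_model = AutoModelForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=2)

# P-tuning configuration: 5 virtual prompt tokens for a sequence-classification task
peft_config = PromptEncoderConfig(
    task_type=TaskType.SEQ_CLS,
    num_virtual_tokens=5,
    encoder_hidden_size=128,  # size of the prompt encoder, chosen for illustration
)
peft_model = get_peft_model(base_model, peft_config)
peft_model.print_trainable_parameters()  # only the prompt encoder and the classification head are trainable

# peft_model can then be trained like any sequence-classification model,
# e.g. with transformers.Trainer or a standard PyTorch training loop.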
Further reading: Chapter 7, Adaptation of Large Models