pytorch实现半监督学习

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

半监督学习(Semi-Supervised Learning,SSL)结合了有监督学习和无监督学习的特点,通常用于部分数据有标签、部分数据无标签的场景。其主要步骤如下:

1. 数据准备

  • 有标签数据(Labeled Data):数据集的一部分带有真实的类别标签。
  • 无标签数据(Unlabeled Data):数据集的另一部分没有标签,仅有特征信息。
  • 数据预处理:对数据进行清理、标准化、特征工程等处理,以保证数据质量。

2. 选择半监督学习方法

常见的半监督学习方法包括:

  • 基于生成模型(Generative Models):如高斯混合模型(GMM)、变分自编码器(VAE)。
  • 基于一致性正则化(Consistency Regularization):如 MixMatch、FixMatch,利用数据增强来约束模型预测一致性。
  • 基于伪标签(Pseudo-Labeling):先用模型预测无标签数据的类别,然后将高置信度的预测作为新标签加入训练。
  • 图神经网络(Graph-Based Methods):如 Label Propagation,通过构造数据之间的图结构传播标签信息。

3. 训练初始模型

  • 仅使用有标签数据训练一个初始模型。
  • 选择合适的损失函数,如交叉熵损失(Cross-Entropy Loss)或均方误差(MSE Loss)。
  • 训练过程中可以使用数据增强、正则化等优化策略。

4. 利用无标签数据增强训练

  • 伪标签方法:用初始模型对无标签数据进行预测,筛选高置信度样本,加入有标签数据训练。
  • 一致性正则化:对无标签数据进行不同变换,要求模型的预测结果一致。
  • 联合训练:构造有监督损失(Supervised Loss)和无监督损失(Unsupervised Loss),综合优化。

5. 模型迭代更新

  • 重新利用训练后的模型预测无标签数据,产生新的伪标签或调整模型参数。
  • 通过半监督策略不断优化模型,使其对无标签数据的预测更加稳定。

6. 评估和测试

  • 使用测试集(通常是有标签的数据)评估模型性能。
  • 选择合适的评估指标,如准确率(Accuracy)、F1-score、AUC-ROC 等。

7. 调优和部署

  • 根据实验结果调整超参数,如伪标签置信度阈值、学习率等。
  • 结合业务需求,将最终模型部署到实际应用中。

关键步骤:

  1. 初始化模型:首先使用有标签数据训练模型。
  2. 生成伪标签:用训练好的模型对无标签数据进行预测,生成伪标签。
  3. 结合有标签和伪标签数据进行训练:用带有标签和无标签(伪标签)数据一起训练模型。
  4. 迭代训练:不断迭代,使用更新的模型生成新的伪标签,进一步优化模型。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt


# 简化的神经网络模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3)  # 缩小卷积层的输出通道
        self.fc1 = nn.Linear(8 * 26 * 26, 10)  # 调整全连接层的输入和输出尺寸

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc1(x)
        return x


# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self, data, labels=None):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if self.labels is not None:
            return self.data[idx], self.labels[idx]
        else:
            return self.data[idx], -1  # 无标签数据


# 半监督训练函数
def pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device, threshold=0.95):
    model.train()
    labeled_loss_value = 0
    pseudo_loss_value = 0

    for (labeled_data, labeled_labels), (unlabeled_data, _) in zip(labeled_loader, unlabeled_loader):
        labeled_data, labeled_labels = labeled_data.to(device), labeled_labels.to(device)
        unlabeled_data = unlabeled_data.to(device)

        # 1. 有标签数据训练
        optimizer.zero_grad()
        labeled_output = model(labeled_data)
        labeled_loss = F.cross_entropy(labeled_output, labeled_labels)
        labeled_loss.backward()

        # 2. 无标签数据伪标签生成
        unlabeled_output = model(unlabeled_data)
        probs = F.softmax(unlabeled_output, dim=1)
        max_probs, pseudo_labels = torch.max(probs, dim=1)

        # 伪标签置信度筛选
        pseudo_mask = max_probs > threshold  # 置信度大于阈值的数据作为伪标签
        if pseudo_mask.sum() > 0:
            pseudo_labels = pseudo_labels[pseudo_mask]
            unlabeled_data_pseudo = unlabeled_data[pseudo_mask]

            # 3. 使用伪标签数据进行训练(确保无标签数据参与反向传播)
            optimizer.zero_grad()  # 清除之前的梯度
            pseudo_output = model(unlabeled_data_pseudo)
            pseudo_loss = F.cross_entropy(pseudo_output, pseudo_labels)
            pseudo_loss.backward()  # 计算反向梯度

        optimizer.step()  # 更新模型参数

        # 累加损失用于展示
        labeled_loss_value += labeled_loss.item()
        if pseudo_mask.sum() > 0:
            pseudo_loss_value += pseudo_loss.item()

    return labeled_loss_value / len(labeled_loader), pseudo_loss_value / len(unlabeled_loader)


# 模拟数据
num_labeled = 1000
num_unlabeled = 5000
data_dim = (1, 28, 28)  # 28x28 灰度图像
num_classes = 10

labeled_data = torch.randn(num_labeled, *data_dim)
labeled_labels = torch.randint(0, num_classes, (num_labeled,))
unlabeled_data = torch.randn(num_unlabeled, *data_dim)

labeled_dataset = CustomDataset(labeled_data, labeled_labels)
unlabeled_dataset = CustomDataset(unlabeled_data)

labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)  # 缩小批量大小
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=True)  # 缩小批量大小

# 模型、优化器和设备设置
device = torch.device("cpu")  # 临时使用 CPU
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练过程并记录损失
num_epochs = 10
labeled_losses = []
pseudo_losses = []

for epoch in range(num_epochs):
    labeled_loss, pseudo_loss = pseudo_labeling_training(model, labeled_loader, unlabeled_loader, optimizer, device)
    labeled_losses.append(labeled_loss)
    pseudo_losses.append(pseudo_loss)
    print(f"Epoch [{epoch + 1}/{num_epochs}] | Labeled Loss: {labeled_loss:.4f} | Pseudo Loss: {pseudo_loss:.4f}")

# 绘制损失曲线
plt.plot(range(num_epochs), labeled_losses, label='Labeled Loss')
plt.plot(range(num_epochs), pseudo_losses, label='Pseudo Label Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training Losses Over Epochs')
plt.show()

# 展示伪标签生成效果(可视化一些样本的伪标签预测结果)
model.eval()
with torch.no_grad():
    sample_unlabeled_data = unlabeled_data[:10].to(device)
    output = model(sample_unlabeled_data)
    probs = F.softmax(output, dim=1)
    _, predicted_labels = torch.max(probs, dim=1)

    # 展示预测的标签
    print("Generated Pseudo Labels for Samples:")
    print(predicted_labels)

    # 假设这些是伪标签预测的图片
    fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    for i, ax in enumerate(axes.flat):
        # 将tensor转换为NumPy数组
        img = sample_unlabeled_data[i].cpu().numpy().squeeze()  # 转为NumPy数组
        ax.imshow(img, cmap='gray')  # 使用灰度显示图像
        ax.set_title(f"Pred: {predicted_labels[i].item()}")
        ax.axis('off')
    plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/965105.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

力扣1022. 从根到叶的二进制数之和(二叉树的遍历思想解决)

Problem: 1022. 从根到叶的二进制数之和 文章目录 题目描述思路复杂度Code 题目描述 思路 遍历思想(利用二叉树的先序遍历) 1.在先序遍历的过程中,用一个变量path记录并更新其经过的路径上的值,当遇到根节点时再将其加到结果值res上; 2.该题…

.NET 中实现生产者-消费者模型,BlockingCollection<T> 和 Channel<T>使用示例

一、方案对比&#xff1a;不同线程安全集合的适用场景 二、推荐方案及示例代码 方案 1&#xff1a;使用 BlockingCollection&#xff08;同步模型&#xff09; public class QueueDemo {private readonly BlockingCollection<int> _blockingCollection new BlockingCo…

C_位运算符及其在单片机寄存器的操作

C语言的位运算符用于直接操作二进制位&#xff0c;本篇简单结束各个位运算符的作业及其在操作寄存器的应用场景。 一、位运算符的简单说明 1、按位与运算符&#xff08;&&#xff09; 功能&#xff1a;按位与运算符对两个操作数的每一位执行与操作。如果两个对应的二进制…

Redis入门概述

1.1、Redis是什么 Redis&#xff1a;官网 高性能带有数据结构的Key-Value内存数据库 Remote Dictionary Server&#xff08;远程字典服务器&#xff09;是完全开源的&#xff0c;使用ANSIC语言编写遵守BSD协议&#xff0c;例如String、Hash、List、Set、SortedSet等等。数据…

个人毕业设计--基于HarmonyOS的旅行助手APP的设计与实现(挖坑)

在行业混了短短几年&#xff0c;却总感觉越混越迷茫&#xff0c;趁着还有心情学习&#xff0c;把当初API9 的毕业设计项目改成API13的项目。先占个坑&#xff0c;把当初毕业设计的文案搬过来 摘要&#xff1a;HarmonyOS&#xff08;鸿蒙系统&#xff09;是华为公司推出的面向全…

C++11详解(二) -- 引用折叠和完美转发

文章目录 2. 右值引用和移动语义2.6 类型分类&#xff08;实践中没什么用&#xff09;2.7 引用折叠2.8 完美转发2.9 引用折叠和完美转发的实例 2. 右值引用和移动语义 2.6 类型分类&#xff08;实践中没什么用&#xff09; C11以后&#xff0c;进一步对类型进行了划分&#x…

车载以太网__传输层

车载以太网中&#xff0c;传输层和实际用的互联网相差无几。本篇文章对传输层中的IP进行介绍 目录 什么是IP&#xff1f; IP和MAC的关系 IP地址分类 私有IP NAT DHCP 为什么要防火墙穿透&#xff1f; 广播 本地广播 直接广播 本地广播VS直接广播 组播 …

大数据学习之Spark分布式计算框架RDD、内核进阶

一.RDD 28.RDD_为什么需要RDD 29.RDD_定义 30.RDD_五大特性总述 31.RDD_五大特性1 32.RDD_五大特性2 33.RDD_五大特性3 34.RDD_五大特性4 35.RDD_五大特性5 36.RDD_五大特性总结 37.RDD_创建概述 38.RDD_并行化创建 演示代码&#xff1a; // 获取当前 RDD 的分区数 Since ( …

第一性原理:游戏开发成本的思考

利润 营收-成本 营收定价x销量x分成比例 销量 曝光量x 点击率x &#xff08;购买率- 退款率&#xff09; 分成比例 100%- 平台抽成- 税- 引擎费- 发行抽成 成本开发成本运营成本 开发成本 人工外包办公地点租金水电设备折旧 人工成本设计成本开发成本迭代修改成本后续内容…

MLA 架构

注&#xff1a;本文为 “MLA 架构” 相关文章合辑。 未整理去重。 DeepSeek 的 MLA 架构 原创 老彭坚持 产品经理修炼之道 2025 年 01 月 28 日 10:15 江西 DeepSeek 的 MLA&#xff08;Multi-head Latent Attention&#xff0c;多头潜在注意力&#xff09;架构 是一种优化…

数据结构-堆和PriorityQueue

1.堆&#xff08;Heap&#xff09; 1.1堆的概念 堆是一种非常重要的数据结构&#xff0c;通常被实现为一种特殊的完全二叉树 如果有一个关键码的集合K{k0,k1,k2,...,kn-1}&#xff0c;把它所有的元素按照完全二叉树的顺序存储在一个一维数组中&#xff0c;如果满足ki<k2i…

BUUCTF_[安洵杯 2019]easy_web(preg_match绕过/MD5强碰撞绕过/代码审计)

打开靶场&#xff0c;出现下面的静态html页面&#xff0c;也没有找到什么有价值的信息。 查看页面源代码 在url里发现了img传参还有cmd 求img参数 这里先从img传参入手&#xff0c;这里我发现img传参好像是base64的样子 进行解码&#xff0c;解码之后还像是base64的样子再次进…

Linux的简单使用和部署4asszaaa0

一.部署 1 环境搭建方式主要有四种: 1. 直接安装在物理机上.但是Linux桌面使用起来非常不友好.所以不建议.[不推荐]. 2. 使用虚拟机软件,将Linux搭建在虚拟机上.但是由于当前的虚拟机软件(如VMWare之类的)存在⼀些bug,会导致环境上出现各种莫名其妙的问题比较折腾.[非常不推荐…

RK3566-移植5.10内核Ubuntu22.04

说明 记录了本人使用泰山派&#xff08;RK3566&#xff09;作为平台并且成功移植5.10.160版本kernel和ubuntu22.04&#xff0c;并且成功配置&连接网络的完整过程。 本文章所用ubuntu下载地址&#xff1a;ubuntu-cdimage-ubuntu-base-releases-22.04-release安装包下载_开源…

二级C语言题解:十进制转其他进制、非素数求和、重复数统计

目录 一、程序填空&#x1f4dd; --- 十进制转其他进制 题目&#x1f4c3; 分析&#x1f9d0; 二、程序修改&#x1f6e0;️ --- 非素数求和 题目&#x1f4c3; 分析&#x1f9d0; 三、程序设计&#x1f4bb; --- 重复数统计 题目&#x1f4c3; 分析&#x1f9d0; 前言…

UE求职Demo开发日志#22 显示人物信息,完善装备的穿脱

1 创建一个人物信息显示的面板&#xff0c;方便测试 简单弄一下&#xff1a; UpdateInfo函数&#xff1a; 就是获取ASC后用属性更新&#xff0c;就不细看了 2 实现思路 在操作目标为装备栏&#xff0c;或者操作起点为装备栏时&#xff0c;交换前先判断能否交换&#xff08;只…

在游戏本(6G显存)上本地部署Deepseek,运行一个14B大语言模型,并使用API访问

在游戏本6G显存上本地部署Deepseek&#xff0c;运行一个14B大语言模型&#xff0c;并使用API访问 环境说明环境准备下载lmstudio运行lmstudio 下载模型从huggingface.co下载模型 配置模型加载模型测试模型API启动API服务代码测试 deepseek在大语言模型上的进步确实不错&#xf…

专业学习|一文了解并实操自适应大邻域搜索(讲解代码)

一、自适应大邻域搜索概念介绍 自适应大邻域搜索&#xff08;Adaptive Large Neighborhood Search&#xff0c;ALNS&#xff09;是一种用于解决组合优化问题的元启发式算法。以下是关于它的详细介绍&#xff1a; -自适应大领域搜索的核心思想是&#xff1a;破坏解、修复解、动…

记录一下 在Mac下用pyinstallter 打包 Django项目

安装: pip install pyinstaller 在urls.py from SheepMasterOneToOne import settings from django.conf.urls.static import staticurlpatterns [path("admin/", admin.site.urls),path(generate_report/export/, ReportAdmin(models.Report, admin.site).generat…

如何在Intellij IDEA中识别一个文件夹下的多个Maven module?

目录 问题描述 理想情况 手动添加Module&#xff0c;配置Intellij IDEA的Project Structure 问题描述 一个文件夹下有多个Maven项目&#xff0c;一个一个开窗口打开可行但是太麻烦。直接open整个文件夹会发现Intellij IDEA默认可能就识别一个或者几个Maven项目&#xff0c;如…