unet脑肿瘤分割完整代码

U-net脑肿瘤分割完整代码

    • 代码目录
    • 数据集
    • 网络
    • 训练
    • 测试

代码目录

在这里插入图片描述

数据集

在这里插入图片描述
https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation

dataset.py

在这里插入代码片import os
import numpy as np
import glob
from PIL import Image
import cv2
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import matplotlib.pyplot as plt

kaggle_3m='./kaggle_3m/'
dirs=glob.glob(kaggle_3m+'*')
#print(dirs)
#os.listdir('./kaggle_3m\\TCGA_HT_A61B_19991127')
data_img=[]
data_label=[]
for subdir in dirs:
    dirname=subdir.split('\\')[-1]
    for filename in os.listdir(subdir):
        img_path=subdir+'/'+filename #图片的绝对路径
        if 'mask' in img_path:
            data_label.append(img_path)
        else:
            data_img.append(img_path)
#data_img[:5] #前几张图 和标签是否对应
#data_label[:5]
data_imgx=[]
for i in range(len(data_label)):#图片和标签对应
    img_mask=data_label[i]
    img=img_mask[:-9]+'.tif'
    data_imgx.append(img)
#data_imgx
data_newimg=[]
data_newlabel=[]
for i in data_label:#获取只有病灶的数据
    value=np.max(cv2.imread(i))
    try:
        if value>0:
            data_newlabel.append(i)
            i_img=i[:-9]+'.tif'
            data_newimg.append(i_img)
    except:
        pass
#查看结果
#data_newimg[:5]
#data_newlabel[:5]
im=data_newimg[20]
im=Image.open(im)
#im.show(im)
im=data_newlabel[20]
im=Image.open(im)
#im.show(im)
#print("可用数据:")
#print(len(data_newlabel))
#print(len(data_newimg))
#数据转换
train_transformer=transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
])
test_transformer=transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor()
])
class BrainMRIdataset(Dataset):
    def __init__(self, img, mask, transformer):
        self.img = img
        self.mask = mask
        self.transformer = transformer

    def __getitem__(self, index):
        img = self.img[index]
        mask = self.mask[index]

        img_open = Image.open(img)
        img_tensor = self.transformer(img_open)

        mask_open = Image.open(mask)
        mask_tensor = self.transformer(mask_open)

        mask_tensor = torch.squeeze(mask_tensor).type(torch.long)

        return img_tensor, mask_tensor

    def __len__(self):
        return len(self.img)
s=1000#划分训练集和测试集
train_img=data_newimg[:s]
train_label=data_newlabel[:s]
test_img=data_newimg[s:]
test_label=data_newlabel[s:]
#加载数据
train_data=BrainMRIdataset(train_img,train_label,train_transformer)
test_data=BrainMRIdataset(test_img,test_label,test_transformer)

dl_train=DataLoader(train_data,batch_size=4,shuffle=True)
dl_test=DataLoader(test_data,batch_size=4,shuffle=True)

img,label=next(iter(dl_train))
plt.figure(figsize=(12,8))
for i,(img,label) in enumerate(zip(img[:4],label[:4])):
    img=img.permute(1,2,0).numpy()
    label=label.numpy()
    plt.subplot(2,4,i+1)
    plt.imshow(img)
    plt.subplot(2,4,i+5)
    plt.imshow(label)

网络

在这里插入图片描述
model.py


import torch
import torch.nn as nn


class Downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Downsample, self).__init__()
        self.conv_relu = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels,
                                      kernel_size=3, padding=1),
                            nn.ReLU(inplace=True),
                            nn.Conv2d(out_channels, out_channels,
                                      kernel_size=3, padding=1),
                            nn.ReLU(inplace=True)
            )
        self.pool = nn.MaxPool2d(kernel_size=2)
    def forward(self, x, is_pool=True):
        if is_pool:
            x = self.pool(x)
        x = self.conv_relu(x)
        return x


class Upsample(nn.Module):
    def __init__(self, channels):
        super(Upsample, self).__init__()
        self.conv_relu = nn.Sequential(
            nn.Conv2d(2 * channels, channels,
                      kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels,
                      kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.upconv_relu = nn.Sequential(
            nn.ConvTranspose2d(channels,
                               channels // 2,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv_relu(x)
        x = self.upconv_relu(x)
        return x


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 1024)

        self.up = nn.Sequential(
            nn.ConvTranspose2d(1024,
                               512,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.ReLU(inplace=True)
        )

        self.up1 = Upsample(512)
        self.up2 = Upsample(256)
        self.up3 = Upsample(128)

        self.conv_2 = Downsample(128, 64)
        self.last = nn.Conv2d(64, 2, kernel_size=1)

    def forward(self, x):
        x1 = self.down1(x, is_pool=False)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)

        x5 = self.up(x5)

        x5 = torch.cat([x4, x5], dim=1)  # 32*32*1024
        x5 = self.up1(x5)  # 64*64*256)
        x5 = torch.cat([x3, x5], dim=1)  # 64*64*512
        x5 = self.up2(x5)  # 128*128*128
        x5 = torch.cat([x2, x5], dim=1)  # 128*128*256
        x5 = self.up3(x5)  # 256*256*64
        x5 = torch.cat([x1, x5], dim=1)  # 256*256*128

        x5 = self.conv_2(x5, is_pool=False)  # 256*256*64

        x5 = self.last(x5)  # 256*256*3
        return x5

if __name__ == '__main__':
    x = torch.rand([8, 3, 256, 256])
    model = Net()
    y = model(x)

训练

train.py

import torch as t
import torch.nn as nn
from tqdm import tqdm  #进度条
import model
from dataset import *


device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

train_data=BrainMRIdataset(train_img,train_label,train_transformer)
test_data=BrainMRIdataset(test_img,test_label,test_transformer)

dl_train=DataLoader(train_data,batch_size=4,shuffle=True)
dl_test=DataLoader(test_data,batch_size=4,shuffle=True)

model = model.Net()
img,label=next(iter(dl_train))
model=model.to('cuda')
img=img.to('cuda')
pred=model(img)
label=label.to('cuda')
loss_fn=nn.CrossEntropyLoss()#交叉熵损失函数
loss_fn(pred,label)
optimizer=torch.optim.Adam(model.parameters(),lr=0.0001)
def train_epoch(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    epoch_iou = [] #交并比

    net=model.train()
    for x, y in tqdm(testloader):
        x, y = x.to('cuda'), y.to('cuda')
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()

            intersection = torch.logical_and(y, y_pred)
            union = torch.logical_or(y, y_pred)
            batch_iou = torch.sum(intersection) / torch.sum(union)
            epoch_iou.append(batch_iou.item())

    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / (total * 256 * 256)

    test_correct = 0
    test_total = 0
    test_running_loss = 0
    epoch_test_iou = []

    t.save(net.state_dict(), './Results/weights/unet_weight/{}.pth'.format(epoch))

    model.eval()
    with torch.no_grad():
        for x, y in tqdm(testloader):
            x, y = x.to('cuda'), y.to('cuda')
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()

            intersection = torch.logical_and(y, y_pred)#预测值和真实值之间的交集
            union = torch.logical_or(y, y_pred)#预测值和真实值之间的并集
            batch_iou = torch.sum(intersection) / torch.sum(union)
            epoch_test_iou.append(batch_iou.item())

    epoch_test_loss = test_running_loss / len(testloader.dataset)
    epoch_test_acc = test_correct / (test_total * 256 * 256)#预测正确的值除以总共的像素点

    print('epoch: ', epoch,
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'IOU:', round(np.mean(epoch_iou), 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3),
          'test_iou:', round(np.mean(epoch_test_iou), 3)
          )

    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc


if __name__ == "__main__":
    epochs=20
    for epoch in range(epochs):
        train_epoch(epoch,
                    model,
                    dl_train,
                    dl_test)


在这里插入图片描述
只跑了20个epoch

测试

test.py

import torch as t
import torch.nn as nn
import model
from dataset import *
import matplotlib.pyplot as plt

device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

train_data=BrainMRIdataset(train_img,train_label,train_transformer)
test_data=BrainMRIdataset(test_img,test_label,test_transformer)

dl_train=DataLoader(train_data,batch_size=4,shuffle=True)
dl_test=DataLoader(test_data,batch_size=4,shuffle=True)

model = model.Net()
img,label=next(iter(dl_train))
model=model.to('cuda')
img=img.to('cuda')
pred=model(img)
label=label.to('cuda')
loss_fn=nn.CrossEntropyLoss()
loss_fn(pred,label)
optimizer=torch.optim.Adam(model.parameters(),lr=0.0001)
def test():
    image, mask = next(iter(dl_test))
    image=image.to('cuda')
    net = model.eval()
    net.to(device)
    net.load_state_dict(t.load("./Results/weights/unet_weight/18.pth"))
    pred_mask = model(image)
    pred_mask=pred_mask
    mask=torch.squeeze(mask)
    pred_mask=pred_mask.cpu()
    num=4
    plt.figure(figsize=(10, 10))
    for i in range(num):
        plt.subplot(num, 4, i*num+1)
        plt.imshow(image[i].permute(1,2,0).cpu().numpy())
        plt.subplot(num, 4, i*num+2)
        plt.imshow(mask[i].cpu().numpy(),cmap='gray')#标签
        plt.subplot(num, 4, i*num+3)
        plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy(),cmap='gray')#预测
    plt.show()


if __name__ == "__main__":
    test()

模型分割效果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/320670.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何在“Microsoft Visual Studio”中使用OpenCV构建应用程序

我在这里描述的所有内容都将应用于 OpenCV 的界面。我首先假设您已经阅读并成功完成了 Windows 中的安装教程。因此,在进一步操作之前,请确保您有一个包含 OpenCV 头文件和二进制文件的 OpenCV 目录,并且您已按照此处所述设置环境变量 设置 O…

多智能体强化学习(概念知识,不涉及具体算法)

目录 一、前置知识1.factored value function2.partially observable MDP (POMDP) problem.2.2 Decentralized-POMDP problem2.3 networked decentralized partially observable Markov decision processes (ND-POMDP) problem2.4 上述两种算法的区别 3. Mean Field Multi-Agen…

深信服超融合HCI版本升级,6.0.0R5升级至6.8.0R2

超融合升级,需要满足以下条件及前期准备: 确认HCI的升级序列号有效升级时长大概在一个半小时,安全起见,需预留至少三至四小时窗口期升级前,需要将所有虚拟机关机,涉及到业务无法访问,需提前通知…

redis中的string相关的部分命令

redis命令手册 redis中文官网查看文档 挨个进行输出调试 Redis Setnx 命令 Redis Getrange 命令 Redis Mset 命令 redis 127.0.0.1:6379> MSET key1 "Hello" key2 "World" OK redis 127.0.0.1:6379> GET key1 "Hello" redis 127.0.0.1:…

datavrap-各种各样的条形图(含详细操作步骤)

静态条形图:正确设置数据即可,导出的图形不会随着时间变化 最普通的静态条形图 黑色系风格的静态条形图 动态条形图:导出的图形会随着时间变化 普通的动态条形图 带数字滚动效果的动态条形图 简单的Top排行榜动态条形图 格式更丰富的Top排行榜…

牛客周赛 Round 28 解题报告 | 珂学家 | 组合数学 + 离散化树状数组

前言 整体评价 还是E稍微有点意思&#xff0c;新周赛好像比预期要简单一些, _. 欢迎关注 珂朵莉 牛客周赛专栏 珂朵莉 牛客小白月赛专栏 A. 小红的新周赛 思路: 模拟 #include <bits/stdc.h>using namespace std;int main() {int res 0;for (int i 0; i < 6; i…

【算法练习】leetcode算法题合集之数组和哈希表篇

重建数组&#xff08;高频&#xff09; LeetCode283.移动零 LeetCode283.移动零 双指针&#xff0c;记录已经处理好的序列的尾部 class Solution {public void moveZeroes(int[] nums) {int k 0;for (int i 0; i < nums.length; i) {if (nums[i] ! 0) {swap(nums, i, k)…

【前后端的那些事】开源!前后端环境搭建+树形结构表格实现

文章目录 1. 前后端项目环境搭建2. table-tree2.1 后端准备2.2 前端准备 前言&#xff1a;最近写项目&#xff0c;发现了一些很有意思的功能&#xff0c;想写文章&#xff0c;录视频把这些内容记录下。但这些功能太零碎&#xff0c;如果为每个功能都单独搭建一个项目&#xff0…

Python之Matplotlib绘图调节清晰度

Python之Matplotlib绘图调节清晰度 文章目录 Python之Matplotlib绘图调节清晰度引言解决方案dpi是什么&#xff1f;效果展示总结 引言 使用python中的matplotlib.pyplot绘图的时候&#xff0c;如果将图片显示出来&#xff0c;或者另存为图片&#xff0c;常常会出现清晰度不够的…

前端js写数据结构与算法

1、什么是数据结构与算法 数据结构&#xff1a;是指数据对象中数据元素之间的相互关系。包括集合结构、线性结构、树形结构、图形结构。 算法&#xff1a;解决问题的思路。 2、时间复杂度 1.是什么? 执行当前算法所“花费的时间” 2.干什么? 在写代码的过程中&#xf…

网工内推 | 信息安全主管,CISP/CISSP认证优先,最高25K

01 武汉华康世纪医疗股份有限公司 招聘岗位&#xff1a;网络安全主管 职责描述&#xff1a; 1、推进公司信息/网络安全管理体系规划、建设、持续改进&#xff0c;促进信息安全管理的推行落地,保障网络、系统与数据安全&#xff1b; 2、维护管理信息/网络管理软件&#xff0c;设…

CSP网络结构实战 - 降低计算量的特征融合方式

CSP网络结构实战 - 降低计算量的特征融合方式 CSP网络结构实战 - 降低计算量的特征融合方式0. 引言1. CSP网络结构简介1.1 核心思想1.2 解决的问题 2. 实验验证2.1 CSP网络模型构建2.2 数据读取与预处理2.3 模型训练与验证 3. 对比实验4. 结果与总结 CSP网络结构实战 - 降低计算…

RT-DETR算法优化改进:多层次特征融合(SDI)结合PConv、DualConv、GSConv,实现二次创新 | UNet v2最新论文

💡💡💡本文独家改进:多层次特征融合(SDI)高效结合DualConv、PConv、GSConv等实现二次创新 1)替代原始的Concat; RT-DETR魔术师专栏介绍: https://blog.csdn.net/m0_63774211/category_12497375.html ✨✨✨魔改创新RT-DETR 🚀🚀🚀引入前沿顶会创新(CVPR…

从零开始做题:逆向wdb_2018_2nd_easyfmt

1.题目信息 2.解题分析 格式化字符串漏洞 如何确定偏移 Do you know repeater? 输入AAAA.%p.%p.%p.%p.%p.%p.%p.%p.%p.%p.%p.%p. 输出AAAA.0xffffd658.0x64.0xf7ffdc08.0xf7ffcd00.0xffffd77c.0x41414141.0x2e70252e.0x252e7025.0x70252e70.0x2e70252e.0x252e7025.0x70252…

【数据结构】排序算法

&#x1f984;个人主页:修修修也 &#x1f38f;所属专栏:数据结构 ⚙️操作环境:Visual Studio 2022 目录 &#x1f38f;排序的定义 &#x1f38f;排序的稳定性 &#x1f4cc;稳定性的定义 &#x1f4cc;稳定性的意义 &#x1f38f;内排序与外排序 &#x1f38f;八大内排…

Linux环境基础开发工具的使用(上)

文章目录 Linux 软件包管理器 yum什么是软件包关于rzsz查看软件包安装软件卸载软件 Linux编辑器 - vimvim的基本概念vim下各模式的切换vim命令模式各命令汇总vim底行模式各命令汇总 配置vim Linux 软件包管理器 yum 什么是软件包 在Linux下安装软件, 一个通常的办法是下载到程…

Vue实战:两种方式创建Vue项目

文章目录 一、实战概述二、实战步骤&#xff08;一&#xff09;安装Vue CLI脚手架1、从Node.js官网下载LTS版本2、安装Node.js到指定目录3、配置Node.js环境变量4、查看node版本5、查看npm版本6、安装Vue Cli脚手架7、查看Vue Cli版本 &#xff08;二&#xff09;命令行方式构建…

Git与VScode联合使用详解

目录 Git与VScode联合使用 方式一 1. 用vscode打开文件夹&#xff0c;如图点击初始化仓库&#xff0c;把此仓库初始为git仓库。 2. 提交文件到本地仓库 3. vscode与github账号绑定 4. 在github中建立远程仓库 5. 本地仓库与远程仓库绑定 方式二 1. 在github上建立远程仓…

魅族MX4pro系统升级、降级

网上的教程都是按住开机键音量上或者下键&#xff0c;但是我按了没用&#xff0c;还是直接点击压缩包管用。 下载系统 官网地址&#xff08;所有手机固件&#xff09;&#xff1a;https://flyme.cn/firmware.html 官方魅族mx4Pro系统&#xff1a;https://flyme.cn/firmwarelis…

通过本质看现象:关于Integer受内部初始化赋值范围限制而出现的有趣现象

文/朱季谦 这是我很多年前的第一篇技术博客&#xff0c;当时作为一名技术小菜鸟&#xff0c;总体而言显得很拙见&#xff0c;但也算是成长路上的一个小脚印&#xff0c;希望能在以后的日子里&#xff0c;可以对JAVA技术有一个更加深入的思考与认识。 前几天我在逛论坛的时候&a…