使用SwinUnet训练自己的数据集

参考博文:

数据集准备

遥感图像多类别语义分割,总共分为7类(包括背景) image: label_rgb label(这里并不是全黑,其中的类别取值为0,1,2,3,4,5,6),此后的训练使用的也是这样的数据

数据地址 百度云: 提取码:2022

数据集处理

数据集的image和label,这个数据集应该提供了rgb格式标签和包含0,1,2,3,4,5,6值的标签,SwinUNet使用的是包含0,1,2,3,4,5,6的标签图像;

1. 数据集

数据集存放在SwinUNet根目录下,image中是原图像,label中是标签图像(共7类,其标签取值为0,1,2,3,4,5,6,7); 如果使用其他数据集,要注意标签的取值。比如如果是二分类。即标签0或255,需要换成0或1

—SwinUNet ---------configs ---------img_datas ---------------train --------------------image --------------------label ---------------test --------------------image --------------------label

2. 在SwinUnet根目录下创建npz.py文件,运行npz.py文件

import glob
import cv2
import numpy as np
import os

def npz(im, la, s):
    images_path = im
    labels_path = la
    path2 = s
    images = os.listdir(images_path)
    for s in images:
        image_path = os.path.join(images_path, s)
        label_path = os.path.join(labels_path, s)

        image = cv2.imread(image_path)
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
		# 标签由三通道转换为单通道
        label = cv2.imread(label_path, flags=0)
        # 保存npz文件 
        np.savez(path2+s[:-4]+".npz",image=image,label=label)

npz(./img_datas/train/image/, ./img_datas/train/label/, ./data/Synapse/train_npz)

npz(./img_datas/test/image/, ./img_datas/test/label/, ./data/Synapse/test_vol_h5)

3. 在SwinUnet根目录下创建txt.py文件,运行txt.py文件

目的是生成./list/list_Synapse/train.txt和./list/list_Synapse/test_vol.txt文件

import os
def write_name(np, tx):
    #npz文件路径
    files = os.listdir(np)
    #txt文件路径
    f = open(tx, w)
    for i in files:
        #name = i.split(\)[-1]
        name = i[:-4]+

        f.write(name)
        
write_name(./data/Synapse/train_npz, ./lists/lists_Synapse/train.txt)
write_name(./data/Synapse/test_vol_h5, ./lists/lists_Synapse/test_vol.txt)

4. 下载预训练权重,放在SwinUnet目录下的pretrained_ckpt文件夹下

链接: 提取码:2022

修改网络

1. 修改train.py文件

比较重要的是类别数量,其他视情况而定

2. 修改./datasets/dataset_synapse.py文件

3. 修改trainer.py文件

此处不知道为什么

4. 运行代码

这些信息可以作为超参传入,如果不能,那么可以使用default=的方式写入默认值 如果设置好啦默认值,那么运行python train.py就可以啦

经验分享 程序员 微信小程序 职场和发展