将FER数据集处理成灰度图片数据-Python代码
因为对Python的一些操作不是很熟悉,最近又想自己搭建一个FER(Facial Expression Recognition)系统,所以今天稍微花了点时间看了一下Python对于csv文件以及对于Image的IO操作,简单处理了一下FER2013数据集。
这个数据集的数据是存放到csv文件中的,因此需要使用到基本的csv操作。FER数据集比较好找,就不在这贴它的链接了。
下面直接放了代码,对基本流程进行了注释。记录下学习过程。后面也会将自己学习过程中搭建的FER网络模型分享出来。
# _*_coding:utf-8_*_ import os import csv import pickle import numpy as np from PIL import Image import matplotlib.pyplot as plt # 获取项目根路径 current_path = os.path.abspath(os.path.dirname(__file__)) root_path = os.path.dirname(current_path) # 进行FER数据集的预处理,还原成图片 class FERDataProcess(object): def __init__(self): self.data_dir = os.path.join(root_path, data) self.data_source = os.path.join(self.data_dir, fer2013.csv) # 数据源文件 self.data_image_dir = os.path.join(self.data_dir, fer_images) # 存储图片数据的文件夹 self.train_data_path = os.path.join(self.data_image_dir, train) self.test_data_path = os.path.join(self.data_image_dir, test) self.validate_data_path = os.path.join(self.data_image_dir, validation) self.pickle_data_path = os.path.join(self.data_dir, pickle) # pickle数据文件夹 def process_to_image(self): """ 将csv文件按照标签以及用途还原成图片文件 :return: None """ # 创建文件夹 if not os.path.exists(self.train_data_path): os.makedirs(self.train_data_path) if not os.path.exists(self.test_data_path): os.makedirs(self.test_data_path) if not os.path.exists(self.validate_data_path): os.makedirs(self.validate_data_path) # 读取csv文件 if not os.path.exists(self.data_source): raise FileNotFoundError(File: {} not exist!.format(self.data_source)) with open(self.data_source) as f: data_csv = csv.reader(f) next(data_csv) # 忽略第一行 index = 1 for (label, pixels, usage) in data_csv: # 区分文件夹 if usage == Training: label_path = os.path.join(self.train_data_path, label) if not os.path.exists(label_path): os.makedirs(label_path) elif usage == PublicTest: label_path = os.path.join(self.validate_data_path, label) if not os.path.exists(label_path): os.makedirs(label_path) elif usage == PrivateTest: label_path = os.path.join(self.test_data_path, label) if not os.path.exists(label_path): os.makedirs(label_path) # 构建图片像素点 reshape_pixels = np.asarray([float(p) for p in pixels.split()]).reshape(48, 48) img = Image.fromarray(reshape_pixels).convert(L) image_name = os.path.join(label_path, {}.jpg.format(index)) print(image_name) img.save(image_name) index += 1
使用这段代码时,实例化 FERDataProcess类对象,调用对象的process_to_image方法即可。
这段代码了并没有统计训练、测试、验证数据集具体有多少记录,代码比较简单,有需要的小伙伴可以自己加上。