数据集cifar10处理
cifar10数据集是比较难训练和分类识别的,对于GAN模型生成cifar10数据对其特性也是一个不小的挑战。如果GAN模型可以很好的生成 cifar10数据图片的话,这个GAN的生成能力已经是很不错的了。今天我们说说如何把下载下来的cifar10数据集输出到GAN模型中用作训练。
数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。这里面有50000张用于训练,构成了5个训练批,每一批10000 张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。 注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。
下面这幅图就是列举了10各类,每一类展示了随机的10张图片:
这10类都是各自独立的,不会出现重叠。下载地址在这里。
拿到这样的数据集后,我们要如何把数据和标签整理好喂入GAN模型或者识别测试的程序中呢?这时候就需要我们对数据集做一定的处理了。 我们看看程序代码:
def load_cifar10(dataset_name):
data_dir = os.path.join("./data", dataset_name)
X_train = []
Y_train = []
dirname = './data/cifar/cifar-10-batches-py'
for i in range(1, 6):
fpath = os.path.join(dirname, 'data_batch_' + str(i))
data, labels = load_batch(fpath)
if i == 1:
X_train = data
Y_train = labels
else:
X_train = np.concatenate([X_train, data], axis=0)
Y_train = np.concatenate([Y_train, labels], axis=0)
fpath = os.path.join(dirname, 'test_batch')
X_test, Y_test = load_batch(fpath)
X_train = np.dstack((X_train[:, :1024], X_train[:, 1024:2048],
X_train[:, 2048:])) / 255.
X_train = np.reshape(X_train, [-1, 32, 32, 3])
X_test = np.dstack((X_test[:, :1024], X_test[:, 1024:2048],
X_test[:, 2048:])) / 255.
X_test = np.reshape(X_test, [-1, 32, 32, 3])
X = np.concatenate((X_train, X_test), axis=0)
# if one_hot:
# Y_train = to_categorical(Y_train, 10)
# Y_test = to_categorical(Y_test, 10)
Y_test = np.asarray(Y_test)
y = np.concatenate((Y_train, Y_test), axis=0).astype(np.int)
seed = 547
np.random.seed(seed) #确保每次生成的随机数相同
np.random.shuffle(X) #将mnist数据集中数据的位置打乱
np.random.seed(seed)
np.random.shuffle(y)
y_vec = np.zeros((len(y), 10), dtype=np.float)
#创建了(70000,10)的标签记录,并且根据cifar10已有标签记录相应的10维数组
for i, label in enumerate(y):
y_vec[i, y[i]] = 1.0
#返回归一化的数据和标签数组
return X, y_vec
程序的大体上是从下载好的cifar10数据集将6个10000张图片先进行合并,将训练集和测试集分开。对于GAN而言可以将训练集和测试集合并。 然后再将数据集整理,将cifar10图片转换为32x32的彩色3通道图片,将图片放入X中保存。对于标签我们将10类数据转换成one_hot形式, 最后存放在y_vec中,我们在实现过程中打乱了数据集的数据位置,这样可以保证训练更加的随机。我们返回两个变量,一个是数据图片X 和对应的one_hot形式的标签y_vec。
在实际的训练过程中,这个形式的数据可以像我们之前讲解的训练mnist数据集 的代码使用,但是网络结构要进行一定的变换。具体的代码由于出文章的原因我暂时就不公布了,等我忙清了,我再分享到github上。
下面是我们训练GAN得到的实验效果图:
cifar10的生成效果还是一般般的,原因也是在于cifar10难于训练的,期待更好的模型完成图像生成。
可选择性参考label的处理。
谢谢观看,希望对您有所帮助,欢迎指正错误,欢迎一起讨论!!!