In [1]:
# 导入TensorFlow中input_data.py文件
In [2]:
from tensorflow.examples.tutorials.mnist import input_data
In [3]:
# 从MNIST_data数据集中读取MNIST数据
In [4]:
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
In [5]:
# 进一步分析MNIST内容
In [6]:
# 加载数据
train_X = mnist.train.images #训练集样本
validation_X = mnist.validation.images #验证集样本
test_X = mnist.test.images #测试集样本
# 加载标签
train_Y = mnist.train.labels #训练集标签
validation_Y = mnist.validation.labels #验证集标签
test_Y = mnist.test.labels #测试集标签
In [7]:
print('训练集样本的大小:', train_X.shape)
print('训练集标签的大小:', train_Y.shape)
In [8]:
print('测试集样本的大小:', test_X.shape)
print('测试集标签的大小:', test_Y.shape)
In [9]:
print('验证集样本的大小:', validation_X.shape)
print('验证集标签的大小:', validation_Y.shape)
In [10]:
import matplotlib.pyplot as plt
In [11]:
# 显示出一张RGB图片看看
im = train_X[1]
im = im.reshape(-1, 28)
plt.imshow(im) # RGB图像
plt.show()
In [12]:
# 显示出一张灰度图片看看
im = train_X[1]
im = im.reshape(-1, 28)
plt.imshow(im,cmap='Greys')
plt.show()
In [13]:
#可视化样本,下面是输出了训练集中前20个样本
fig, ax = plt.subplots(nrows=4,ncols=5,sharex='all',sharey='all')
ax = ax.flatten()
for i in range(20):
img = train_X[i].reshape(28, 28)
ax[i].imshow(img,cmap='Greys')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
In [14]:
#查看数据,例如训练集中第一个样本的内容和标签
print(train_X[0]) #是一个包含784个元素且值在[0,1]之间的向量
print(train_Y[0])