-
TensorFlow2.0(9):TensorBoard可视化
In [1]:
import tensorflow as tf
import tensorboard
import datetime
In [20]:
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
In [21]:
model = create_model()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 定义日志目录,必须是启动web应用时指定目录的子目录,建议使用日期时间作为子目录名
log_dir="/home/chb/jupyter/logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1) # 定义TensorBoard对象
model.fit(x=x_train,
y=y_train,
epochs=5,
validation_data=(x_test, y_test),
callbacks=[tensorboard_callback]) # 将定义好的TensorBoard对象作为回调传给fit方法,这样就将TensorBoard嵌入了模型训练过程
Out[21]:
In [15]:
import datetime
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential ,metrics
In [16]:
def preprocess(x, y):
x = tf.cast(x, dtype=tf.float32) / 255.
y = tf.cast(y, dtype=tf.int32)
return x, y
In [17]:
(x, y), (x_test, y_test) = datasets.fashion_mnist.load_data()
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(10000).batch(128)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.map(preprocess).batch(128)
In [18]:
model = Sequential([
layers.Dense(256, activation=tf.nn.relu), # [b, 784] --> [b, 256]
layers.Dense(128, activation=tf.nn.relu), # [b, 256] --> [b, 128]
layers.Dense(64, activation=tf.nn.relu), # [b, 128] --> [b, 64]
layers.Dense(32, activation=tf.nn.relu), # [b, 64] --> [b, 32]
layers.Dense(10) # [b, 32] --> [b, 10]
]
)
model.build(input_shape=[None,28*28])
model.summary()
optimizer = optimizers.Adam(lr=1e-3)#1e-3
# 指定日志目录
log_dir="/home/chb/jupyter/logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = tf.summary.create_file_writer(log_dir) # 创建日志文件句柄
In [19]:
db_iter = iter(db)
images = next(db_iter)
# 必须进行reshape,第一个纬度是图片数量或者说簇大小,28*28是图片大小,1是chanel,因为只灰度图片所以是1
images = tf.reshape(x, (-1, 28, 28, 1))
with summary_writer.as_default(): # 将第一个簇的图片写入TensorBoard
tf.summary.image('Training data', images, max_outputs=5, step=0) # max_outputs设置最大显示图片数量
In [20]:
tf.summary.trace_on(graph=True, profiler=True)
for epoch in range(30):
train_loss = 0
train_num = 0
for step, (x, y) in enumerate(db):
x = tf.reshape(x, [-1, 28*28])
with tf.GradientTape() as tape:
logits = model(x)
y_onehot = tf.one_hot(y,depth=10)
loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
loss_ce = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
loss_ce = tf.reduce_mean(loss_ce) # 计算整个簇的平均loss
grads = tape.gradient(loss_ce, model.trainable_variables) # 计算梯度
optimizer.apply_gradients(zip(grads, model.trainable_variables)) # 更新梯度
train_loss += float(loss_ce)
train_num += x.shape[0]
loss = train_loss / train_num # 计算每一次迭代的平均loss
with summary_writer.as_default(): # 将loss写入TensorBoard
tf.summary.scalar('train_loss', train_loss, step=epoch)
total_correct = 0
total_num = 0
for x,y in db_test: # 用测试集验证每一次迭代后的准确率
x = tf.reshape(x, [-1, 28*28])
logits = model(x)
prob = tf.nn.softmax(logits, axis=1)
pred = tf.argmax(prob, axis=1)
pred = tf.cast(pred, dtype=tf.int32)
correct = tf.equal(pred, y)
correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
total_correct += int(correct)
total_num += x.shape[0]
acc = total_correct / total_num # 平均准确率
with summary_writer.as_default(): # 将acc写入TensorBoard
tf.summary.scalar('test_acc', acc, step=epoch)
print(epoch, 'train_loss:',loss,'test_acc:', acc)
最新更新
博克-定制图例
博克-注释和图例
Bokeh–添加小部件
向博克图添加标签
将交互式滑块添加到博克图
在 Bokeh 中添加按钮
谷歌、微软、Meta?谁才是 Python 最大的金
Objective-C语法之代码块(block)的使用
URL Encode
go语言写http踩得坑
动手分析SQL Server中的事务中使用的锁
openGauss内核分析:SQL by pass & 经典执行
一招教你如何高效批量导入与更新数据
天天写SQL,这些神奇的特性你知道吗?
openGauss内核分析:执行计划生成
[IM002]Navicat ODBC驱动器管理器 未发现数据
初入Sql Server 之 存储过程的简单使用
SQL Server -- 解决存储过程传入参数作为s
[SQL Server]按照设定的周别的第一天算任意
Linux下定时自动备份Docker中所有SqlServer数
武装你的WEBAPI-OData入门
武装你的WEBAPI-OData便捷查询
武装你的WEBAPI-OData分页查询
武装你的WEBAPI-OData资源更新Delta
5. 武装你的WEBAPI-OData使用Endpoint 05-09
武装你的WEBAPI-OData之API版本管理
武装你的WEBAPI-OData常见问题
武装你的WEBAPI-OData聚合查询
OData WebAPI实践-OData与EDM
OData WebAPI实践-Non-EDM模式