首页 > Python基础教程 >
-
在C#下使用TensorFlow.NET训练自己的数据集(2)
模型训练和模型保存
-
Batch数据集的读取,采用了 SharpCV 的cv2.imread,可以直接读取本地图像文件至NDArray,实现CV和Numpy的无缝对接;
-
使用.NET的异步线程安全队列BlockingCollection<T>,实现TensorFlow原生的队列管理器FIFOQueue;
-
在训练模型的时候,我们需要将样本从硬盘读取到内存之后,才能进行训练。我们在会话中运行多个线程,并加入队列管理器进行线程间的文件入队出队操作,同时限制队列容量:主线程利用队列中的数据进行训练,另一个线程负责本地文件的IO读取。这样数据的读取和模型的训练可以异步进行,从而缩短训练时间。
-
-
模型的保存,可以选择每轮训练结束后都保存,或者只保存验证准确率最高的最佳模型。
#region Train public void Train(Session sess) { // Number of training iterations in each epoch var num_tr_iter = (ArrayLabel_Train.Length) / batch_size; var init = tf.global_variables_initializer(); sess.run(init); var saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10); path_model = Name + "\\MODEL"; Directory.CreateDirectory(path_model); float loss_val = 100.0f; float accuracy_val = 0f; var sw = new Stopwatch(); sw.Start(); foreach (var epoch in range(epochs)) { print($"Training epoch: {epoch + 1}"); // Randomly shuffle the training data at the beginning of each epoch (ArrayFileName_Train, ArrayLabel_Train) = ShuffleArray(ArrayLabel_Train.Length, ArrayFileName_Train, ArrayLabel_Train); y_train = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Train)]; //decay learning rate if (learning_rate_step != 0) { if ((epoch != 0) && (epoch % learning_rate_step == 0)) { learning_rate_base = learning_rate_base * learning_rate_decay; if (learning_rate_base <= learning_rate_min) { learning_rate_base = learning_rate_min; } sess.run(tf.assign(learning_rate, learning_rate_base)); } } //Load local images asynchronously,use queue,improve train efficiency BlockingCollection<(NDArray c_x, NDArray c_y, int iter)> BlockC = new BlockingCollection<(NDArray C1, NDArray C2, int iter)>(TrainQueueCapa); Task.Run(() => { foreach (var iteration in range(num_tr_iter)) { var start = iteration * batch_size; var end = (iteration + 1) * batch_size; (NDArray x_batch, NDArray y_batch) = GetNextBatch(sess, ArrayFileName_Train, y_train, start, end); BlockC.Add((x_batch, y_batch, iteration)); } BlockC.CompleteAdding(); }); foreach (var item in BlockC.GetConsumingEnumerable()) { sess.run(optimizer, (x, item.c_x), (y, item.c_y)); if (item.iter % display_freq == 0) { // Calculate and display the batch loss and accuracy var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, item.c_x), new FeedItem(y, item.c_y)); loss_val = result[0]; accuracy_val = result[1]; print("CNN:" + ($"iter 
{item.iter.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms")); sw.Restart(); } } // Run validation after every epoch (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid)); print("CNN:" + "---------------------------------------------------------"); print("CNN:" + $"gloabl steps: {sess.run(gloabl_steps) },learning rate: {sess.run(learning_rate)}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); print("CNN:" + "---------------------------------------------------------"); if (SaverBest) { if (accuracy_val > max_accuracy) { max_accuracy = accuracy_val; saver.save(sess, path_model + "\\CNN_Best"); print("CKPT Model is save."); } } else { saver.save(sess, path_model + string.Format("\\CNN_Epoch_{0}_Loss_{1}_Acc_{2}", epoch, loss_val, accuracy_val)); print("CKPT Model is save."); } } Write_Dictionary(path_model + "\\dic.txt", Dict_Label); } private void Write_Dictionary(string path, Dictionary<Int64, string> mydic) { FileStream fs = new FileStream(path, FileMode.Create); StreamWriter sw = new StreamWriter(fs); foreach (var d in mydic) { sw.Write(d.Key + "," + d.Value + "\r\n"); } sw.Flush(); sw.Close(); fs.Close(); print("Write_Dictionary"); } private (NDArray, NDArray) Randomize(NDArray x, NDArray y) { var perm = np.random.permutation(y.shape[0]); np.random.shuffle(perm); return (x[perm], y[perm]); } private (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) { var slice = new Slice(start, end); var x_batch = x[slice]; var y_batch = y[slice]; return (x_batch, y_batch); } private unsafe (NDArray, NDArray) GetNextBatch(Session sess, string[] x, NDArray y, int start, int end) { NDArray x_batch = np.zeros(end - start, img_h, img_w, n_channels); int n = 0; for (int i = start; i < end; i++) { NDArray img4 = cv2.imread(x[i], IMREAD_COLOR.IMREAD_GRAYSCALE); x_batch[n] = sess.run(normalized, 
(decodeJpeg, img4)); n++; } var slice = new Slice(start, end); var y_batch = y[slice]; return (x_batch, y_batch); } #endregion
测试集预测
-
训练完成的模型对test数据集进行预测,并统计准确率
-
计算图中增加了一个提取预测结果Top-1的概率的节点,最后测试集预测的时候可以把详细的预测数据进行输出,方便实际工程中进行调试和优化。
public void Test(Session sess) { (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test)); print("CNN:" + "---------------------------------------------------------"); print("CNN:" + $"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); print("CNN:" + "---------------------------------------------------------"); (Test_Cls, Test_Data) = sess.run((cls_prediction, prob), (x, x_test)); } private void TestDataOutput() { for (int i = 0; i < ArrayLabel_Test.Length; i++) { Int64 real = ArrayLabel_Test[i]; int predict = (int)(Test_Cls[i]); var probability = Test_Data[i, predict]; string result = (real == predict) ? "OK" : "NG"; string fileName = ArrayFileName_Test[i]; string real_str = Dict_Label[real]; string predict_str = Dict_Label[predict]; print((i + 1).ToString() + "|" + "result:" + result + "|" + "real_str:" + real_str + "|" + "predict_str:" + predict_str + "|" + "probability:" + probability.GetSingle().ToString() + "|" + "fileName:" + fileName); } }
总结
本文主要是.NET下的TensorFlow在实际工业现场视觉检测项目中的应用,使用SciSharp的TensorFlow.NET构建了简单的CNN图像分类模型,该模型包含输入层、卷积与池化层、扁平化层、全连接层和输出层,这些层都是CNN分类模型的必要的层,针对工业现场的实际图像进行了分类,分类准确性较高。
完整代码可以直接用于大家自己的数据集进行训练,已经在工业现场经过大量测试,可以在GPU或CPU环境下运行,只需要更换tensorflow.dll文件即可实现训练环境的切换。
同时,训练完成的模型文件,可以使用 “CKPT+Meta” 或 冻结成“PB” 2种方式,进行现场的部署,模型部署和现场应用推理可以全部在.NET平台下进行,实现工业现场程序的无缝对接。摆脱了以往Python下 需要通过Flask搭建服务器进行数据通讯交互 的方式,现场部署应用时无需配置Python和TensorFlow的环境【无需对工业现场的原有PC升级安装一大堆环境】,整个过程全部使用传统的.NET的DLL引用的方式。