VB.net 2010 视频教程 VB.net 2010 视频教程 python基础视频教程
SQL Server 2008 视频教程 c#入门经典教程 Visual Basic从入门到精通视频教程
当前位置:
首页 > Python基础教程 >
  • 在C#下使用TensorFlow.NET训练自己的数据集(2)

// NOTE(review): this excerpt begins mid-method — the tail below belongs to a
// conv-layer builder whose signature and opening statements are outside this
// chunk; only comments/formatting were changed, no code tokens.
[] { num_filters });
// tf.summary.histogram("bias", b);
// Convolve, add the bias, then apply ReLU non-linearity.
var layer = tf.nn.conv2d(x, W, strides: new[] { 1, stride, stride, 1 }, padding: "SAME");
layer += b;
return tf.nn.relu(layer);
});
}

/// <summary>
/// Create a max pooling layer
/// </summary>
/// <param name="x">input to max-pooling layer</param>
/// <param name="ksize">size of the max-pooling filter</param>
/// <param name="stride">stride of the max-pooling filter</param>
/// <param name="name">layer name</param>
/// <returns>The output array</returns>
private Tensor max_pool(Tensor x, int ksize, int stride, string name)
{
    // Window and stride are applied only over the spatial dims (batch/channel = 1).
    return tf.nn.max_pool(x, ksize: new[] { 1, ksize, ksize, 1 }, strides: new[] { 1, stride, stride, 1 }, padding: "SAME", name: name);
}

/// <summary>
/// Flattens the output of the convolutional layer to be fed into fully-connected layer
/// </summary>
/// <param name="layer">input array</param>
/// <returns>flattened array</returns>
private Tensor flatten_layer(Tensor layer)
{
    return tf_with(tf.variable_scope("Flatten_layer"), delegate
    {
        var layer_shape = layer.TensorShape;
        // .size over the shape slice [1, 4) — presumably h*w*c of a NHWC tensor; batch dim (-1) is kept.
        var num_features = layer_shape[new Slice(1, 4)].size;
        var layer_flat = tf.reshape(layer, new[] { -1, num_features });
        return layer_flat;
    });
}

/// <summary>
/// Create a weight variable with appropriate initialization
/// </summary>
/// <param name="name">variable name within the current scope</param>
/// <param name="shape">weight tensor shape</param>
/// <returns>the created (or reused) variable</returns>
private RefVariable weight_variable(string name, int[] shape)
{
    // Truncated normal keeps initial weights small and bounded.
    var initer = tf.truncated_normal_initializer(stddev: 0.01f);
    return tf.get_variable(name, dtype: tf.float32, shape: shape, initializer: initer);
}

/// <summary>
/// Create a bias variable with appropriate initialization
/// </summary>
/// <param name="name">variable name within the current scope</param>
/// <param name="shape">bias vector shape</param>
/// <returns>the created (or reused) variable</returns>
private RefVariable bias_variable(string name, int[] shape)
{
    // Biases start at zero; shape comes from the constant initializer.
    var initial = tf.constant(0f, shape: shape, dtype: tf.float32);
    return tf.get_variable(name, dtype: tf.float32, initializer: initial);
}

/// <summary>
/// Create a fully-connected layer
/// </summary>
/// <param name="x">input from previous layer</param>
/// <param name="num_units">number of hidden units in the fully-connected layer</param>
/// <param name="name">layer name</param>
/// <param name="use_relu">boolean to add ReLU non-linearity (or not)</param>
/// <returns>The output array</returns>
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
{
    return tf_with(tf.variable_scope(name), delegate
    {
        // Weight matrix is (in_dim x num_units); in_dim is read from the incoming tensor.
        var in_dim = x.shape[1];
        var W = weight_variable("W_" + name, shape: new[] { in_dim, num_units });
        var b = bias_variable("b_" + name, new[] { num_units });
        var layer = tf.matmul(x, W) + b;
        if (use_relu)
            layer = tf.nn.relu(layer);
        return layer;
    });
}
#endregion
复制代码

 

 

模型训练和模型保存

  • Batch数据集的读取,采用了 SharpCV 的cv2.imread,可以直接读取本地图像文件至NDArray,实现CV和Numpy的无缝对接

  • 使用.NET的线程安全阻塞队列BlockingCollection<T>,实现与TensorFlow原生的队列管理器FIFOQueue等效的异步批数据加载功能

    • 在训练模型的时候,我们需要将样本从硬盘读取到内存之后,才能进行训练。我们在会话中运行多个线程,并加入队列管理器进行线程间的文件入队出队操作,并限制队列容量,主线程可以利用队列中的数据进行训练,另一个线程进行本地文件的IO读取,这样可以实现数据的读取和模型的训练是异步的,降低训练时间。

  • 模型的保存,可以选择每轮训练都保存,或最佳训练模型保存

    复制代码
    #region Train
    /// <summary>
    /// Run the full training loop: per-epoch shuffle, step-decayed learning rate,
    /// asynchronous batch loading through a bounded producer/consumer queue,
    /// per-epoch validation, and checkpoint saving (best-only or every epoch).
    /// </summary>
    /// <param name="sess">session holding the graph built elsewhere in this class</param>
    public void Train(Session sess)
    {
        // Number of training iterations in each epoch
        var num_tr_iter = (ArrayLabel_Train.Length) / batch_size;

        var init = tf.global_variables_initializer();
        sess.run(init);

        var saver = tf.train.Saver(tf.global_variables(), max_to_keep: 10);

        path_model = Name + "\\MODEL";
        Directory.CreateDirectory(path_model);

        float loss_val = 100.0f;
        float accuracy_val = 0f;

        var sw = new Stopwatch();
        sw.Start();
        foreach (var epoch in range(epochs))
        {
            print($"Training epoch: {epoch + 1}");
            // Randomly shuffle the training data at the beginning of each epoch
            (ArrayFileName_Train, ArrayLabel_Train) = ShuffleArray(ArrayLabel_Train.Length, ArrayFileName_Train, ArrayLabel_Train);
            // Rebuild the one-hot label matrix to match the shuffled order.
            y_train = np.eye(Dict_Label.Count)[new NDArray(ArrayLabel_Train)];

            // Step decay: every learning_rate_step epochs multiply the rate by
            // learning_rate_decay, clamped at learning_rate_min.
            if (learning_rate_step != 0)
            {
                if ((epoch != 0) && (epoch % learning_rate_step == 0))
                {
                    learning_rate_base = learning_rate_base * learning_rate_decay;
                    if (learning_rate_base <= learning_rate_min) { learning_rate_base = learning_rate_min; }
                    sess.run(tf.assign(learning_rate, learning_rate_base));
                }
            }

            // Load local images asynchronously: the bounded BlockingCollection lets
            // the disk-IO producer run ahead of training by at most TrainQueueCapa
            // batches. (Tuple element names unified with the consumer side.)
            var BlockC = new BlockingCollection<(NDArray c_x, NDArray c_y, int iter)>(TrainQueueCapa);
            var producer = Task.Run(() =>
            {
                try
                {
                    foreach (var iteration in range(num_tr_iter))
                    {
                        var start = iteration * batch_size;
                        var end = (iteration + 1) * batch_size;
                        (NDArray x_batch, NDArray y_batch) = GetNextBatch(sess, ArrayFileName_Train, y_train, start, end);
                        BlockC.Add((x_batch, y_batch, iteration));
                    }
                }
                finally
                {
                    // FIX: always complete the collection, even when a batch fails to
                    // load; otherwise GetConsumingEnumerable() below blocks forever.
                    BlockC.CompleteAdding();
                }
            });

            foreach (var item in BlockC.GetConsumingEnumerable())
            {
                sess.run(optimizer, (x, item.c_x), (y, item.c_y));

                if (item.iter % display_freq == 0)
                {
                    // Calculate and display the batch loss and accuracy
                    var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, item.c_x), new FeedItem(y, item.c_y));
                    loss_val = result[0];
                    accuracy_val = result[1];
                    print("CNN:" + ($"iter {item.iter.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms"));
                    sw.Restart();
                }
            }
            // FIX: surface any exception thrown by the loader task (previously lost)
            // and release the queue's resources (BlockingCollection is IDisposable).
            producer.Wait();
            BlockC.Dispose();

            // Run validation after every epoch
            (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid));
            print("CNN:" + "---------------------------------------------------------");
            // FIX: "gloabl steps" -> "global steps" in the log text only; the field
            // name gloabl_steps is declared elsewhere and is left untouched.
            print("CNN:" + $"global steps: {sess.run(gloabl_steps) },learning rate: {sess.run(learning_rate)}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
            print("CNN:" + "---------------------------------------------------------");

            if (SaverBest)
            {
                // Keep only the best-so-far checkpoint, judged by validation accuracy.
                if (accuracy_val > max_accuracy)
                {
                    max_accuracy = accuracy_val;
                    saver.save(sess, path_model + "\\CNN_Best");
                    print("CKPT model saved."); // FIX: message grammar ("is save.")
                }
            }
            else
            {
                // Save a checkpoint every epoch, tagged with its metrics.
                saver.save(sess, path_model + string.Format("\\CNN_Epoch_{0}_Loss_{1}_Acc_{2}", epoch, loss_val, accuracy_val));
                print("CKPT model saved."); // FIX: message grammar ("is save.")
            }
        }
        // Persist the label-id -> label-name mapping next to the checkpoints.
        Write_Dictionary(path_model + "\\dic.txt", Dict_Label);
    }
    /// <summary>
    /// Write the label dictionary to <paramref name="path"/> as "key,value" lines
    /// (CRLF-terminated), overwriting any existing file.
    /// </summary>
    /// <param name="path">destination file path</param>
    /// <param name="mydic">label id -> label name mapping to persist</param>
    private void Write_Dictionary(string path, Dictionary<Int64, string> mydic)
    {
        // FIX: the original leaked the FileStream/StreamWriter if a write threw;
        // using-blocks guarantee flush + close on every exit path.
        using (FileStream fs = new FileStream(path, FileMode.Create))
        using (StreamWriter sw = new StreamWriter(fs))
        {
            foreach (var d in mydic) { sw.Write(d.Key + "," + d.Value + "\r\n"); }
            sw.Flush();
        }
        print("Write_Dictionary");
    }
    /// <summary>
    /// Return x and y reordered by one shared random permutation, so sample/label
    /// pairs stay aligned.
    /// </summary>
    /// <param name="x">sample array, first axis = sample index</param>
    /// <param name="y">label array, first axis = sample index</param>
    /// <returns>(shuffled x, shuffled y)</returns>
    private (NDArray, NDArray) Randomize(NDArray x, NDArray y)
    {
        // FIX: np.random.permutation(n) already yields a random ordering; the
        // original's extra np.random.shuffle(perm) call was redundant work.
        var perm = np.random.permutation(y.shape[0]);
        return (x[perm], y[perm]);
    }
    /// <summary>
    /// Slice the in-memory batch [start, end) out of pre-loaded sample and label arrays.
    /// </summary>
    /// <param name="x">all samples, first axis = sample index</param>
    /// <param name="y">all labels, first axis = sample index</param>
    /// <param name="start">first sample index (inclusive)</param>
    /// <param name="end">last sample index (exclusive)</param>
    /// <returns>(x_batch, y_batch) covering the requested window</returns>
    private (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
    {
        var window = new Slice(start, end);
        return (x[window], y[window]);
    }
    /// <summary>
    /// Load one batch of images from disk via SharpCV and pair it with the matching
    /// label rows; each image is run through the pre-built normalization sub-graph.
    /// </summary>
    /// <param name="sess">session used to run the normalization sub-graph</param>
    /// <param name="x">full list of image file paths, indexed by sample</param>
    /// <param name="y">full label matrix, first axis = sample index</param>
    /// <param name="start">first sample index (inclusive)</param>
    /// <param name="end">last sample index (exclusive)</param>
    /// <returns>(x_batch, y_batch) for samples [start, end)</returns>
    private unsafe (NDArray, NDArray) GetNextBatch(Session sess, string[] x, NDArray y, int start, int end)
    {
        // Pre-allocate the batch buffer: (batch, height, width, channels).
        NDArray x_batch = np.zeros(end - start, img_h, img_w, n_channels);
        int n = 0;
        for (int i = start; i < end; i++)
        {
            // NOTE(review): images are read as grayscale (IMREAD_GRAYSCALE) even though
            // the buffer is allocated with n_channels — confirm n_channels == 1 upstream.
            NDArray img4 = cv2.imread(x[i], IMREAD_COLOR.IMREAD_GRAYSCALE);
            // Feed the raw image into the decodeJpeg placeholder and fetch the
            // normalized tensor defined elsewhere in this class.
            x_batch[n] = sess.run(normalized, (decodeJpeg, img4));
            n++;
        }
        // Labels are already in memory; just slice the matching rows.
        var slice = new Slice(start, end);
        var y_batch = y[slice];
        return (x_batch, y_batch);
    }
    #endregion   
    复制代码

     

     

测试集预测

  • 训练完成的模型对test数据集进行预测,并统计准确率

  • 计算图中增加了一个提取预测结果Top-1的概率的节点,最后测试集预测的时候可以把详细的预测数据进行输出,方便实际工程中进行调试和优化。

    复制代码
    /// <summary>
    /// Evaluate the trained model on the whole test set and cache the per-sample
    /// predictions (class ids and probabilities) for later reporting.
    /// </summary>
    /// <param name="sess">session holding the trained graph</param>
    public void Test(Session sess)
    {
        // Loss and accuracy over the full test set in a single run.
        (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test));
        print("CNN:" + "---------------------------------------------------------");
        print("CNN:" + $"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
        print("CNN:" + "---------------------------------------------------------");

        // Cache predicted class ids and class probabilities for TestDataOutput().
        (Test_Cls, Test_Data) = sess.run((cls_prediction, prob), (x, x_test));

    }
    /// <summary>
    /// Emit one report line per test sample: OK/NG verdict, true and predicted label
    /// names, the Top-1 probability, and the source file name. Relies on Test()
    /// having filled Test_Cls and Test_Data.
    /// </summary>
    private void TestDataOutput()
    {
        for (int i = 0; i < ArrayLabel_Test.Length; i++)
        {
            Int64 actual = ArrayLabel_Test[i];
            int predicted = (int)(Test_Cls[i]);
            // Probability assigned to the predicted class for this sample.
            var prob_nd = Test_Data[i, predicted];
            string verdict = actual == predicted ? "OK" : "NG";
            print($"{i + 1}|result:{verdict}|real_str:{Dict_Label[actual]}|"
                  + $"predict_str:{Dict_Label[predicted]}|probability:{prob_nd.GetSingle()}|"
                  + $"fileName:{ArrayFileName_Test[i]}");
        }
    }
    复制代码

     

     

总结

本文主要是.NET下的TensorFlow在实际工业现场视觉检测项目中的应用,使用SciSharp的TensorFlow.NET构建了简单的CNN图像分类模型,该模型包含输入层、卷积与池化层、扁平化层、全连接层和输出层,这些层都是CNN分类模型的必要的层,针对工业现场的实际图像进行了分类,分类准确性较高。

完整代码可以直接用于大家自己的数据集进行训练,已经在工业现场经过大量测试,可以在GPU或CPU环境下运行,只需要更换tensorflow.dll文件即可实现训练环境的切换。

同时,训练完成的模型文件,可以使用 “CKPT+Meta” 或 冻结成“PB” 2种方式,进行现场的部署,模型部署和现场应用推理可以全部在.NET平台下进行,实现工业现场程序的无缝对接。摆脱了以往Python下 需要通过Flask搭建服务器进行数据通讯交互 的方式,现场部署应用时无需配置Python和TensorFlow的环境【无需对工业现场的原有PC升级安装一大堆环境】,整个过程全部使用传统的.NET的DLL引用的方式。


相关教程
关于我们--广告服务--免责声明--本站帮助-友情链接--版权声明--联系我们       黑ICP备07002182号