当前位置:
首页 > temp > python入门教程 >
-
图像数据识别的模型
模型参数设置与模型构建及训练
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.callbacks import ModelCheckpoint
model = Sequential()
model.add(Dense(units=64, input_dim=100))
model.add(Activation("relu"))
model.add(Dense(units=64, input_dim=100))
model.add(Activation("softmax"))
#完成模型的搭建后,我们需要使用.compile()方法来编译模型:
model.compile(loss='categorical_croosentropy',metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=32)
loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128)
classes = model.predict(x_test, batch_size=128)
model.save('my_model.h5')
#更改loss函数和优化器
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy])
checkpointer = ModelCheckpoint(filepath="checkpoint-{epoch:02d}e-val_acc_{val_acc:2f}.hdf5"
,save_best_only=true, verbose=1, period=50)
model.fit(data,labels, epoch=10,batch_size=32, callbacks=[checkpointer])
#调用Checkpoint保存的model
model = load_model('checkpoint-05e-val_acc_0.58.hdf5')
#模型选取
from keras.application.vgg16 import VGG16
from keras.application.vgg19 import VGG19
from keras.application.inception_v3 import InceptionV3
from keras.application.resnet50 import ResNet50
model_vgg16_conv = VGG16(weights=None, include_top=False, pooling='avg')
output_vgg16_conv = model_vgg16_conv(input)
x = output_vgg16_conv
input = Input(shape=(width,height,channel),name='image_input')
x = Dense(clazz, activation='softmax', name='predictions')(x)
#Create your own model
model = Model(inputs=input, outputs=x)
model.complie(loss=keras.losses.categorical_crossentropy,
optimizer=keras.optimizers.Adam(lr=lr,decay=0),metrics=['acc])
#load all Images
def LoadImageGen(files_data, labels_data,batch=32, label="label"):
start = 0
while start < len(file_data):
stop = start + batch
if stop > len(files_data):
stop = len(file_data)
imgs = []
labels = []
for i in range(start, stop):
imgs.append(LoadImage(file_data[i]))
labels.append(label_data[i])
yield(np.array(imgs),np.array(labels))
if start + batch < len(files_data):
start +=batch
else:
zip_data = list(zip(files_data,labels_data))
random.shuffle(zip_data)
files_data, labels_data = zip(*zip_data)
start=0
# load Images to training model
model.fit_generator(
LoadImageGen(train_x,train_y, batch=batch,label = "train"),
steps_per_epoch=int(len(train_x)/batch),
epochs = epoch,
verbose = 1,
validation_data = LoadImageGen(test_x,test_y, batch=batch,label = "test"),
validation_steps = int(len(test_x)/batch),
callbacks=[
EarlyStopping(monitor='val_acc',patience=patienceEpoch)),
modelCheckpoint
]
)
#模型可视化,Tensoborad
#采用keras特有的fit()进行指定callbacks函数即可,代码如下
from keras.callbacks import TensorBoard
from keras.models import Sequential
……
tbCallBack = keras.callbacks.TensorBoard(log_dir='tensorboard',
histogram_freq=1,write_graph=True,write_images=True)
model_history = model.fit(
X_train, y_train,batch_size=batch_size,epochs=epochs,verbose=1,valiation_data=(X_val,y_val),
callbacks = [EarlyStopping(patience=patience,model='min',verbose=1),tbCallBack])
#数据测试:对测试数据集进行验证,并输出测试结果
from keras.models import load_model
model = load_model(*./*.hdf5)
predictLabel=[]
for imgName in os.list(path_base):
img = LoadImage(os.path.join(path_base, imgName))
res = np.argmax(model.predict(np.array([img]))
predictLabel.append(LABELS[res])
acc = round(metrics.precision_score(trueLabel,predictLabel,average='macro'),4)
recall = round(metrics.recall_score(trueLabel,predictLabel,average='macro'),4)
f1_score = round(metrics.f1_score(trueLabel, predictLabel, average='macro'),4)
print("Test acc:{}, Test recall, Test F1_score:{}".format(acc_recall,f1_score))
VGG16:VGG(visual geometry group,超分辨率测试序列)
参考:https://zhuanlan.zhihu.com/p/41423739
共包含13卷积层(Convolutional Layer,表示为conv3-XXXX)+3个连接层(Fully connected Layer,表示为FC-XXXX)+5个池化层(Pool layer,表示maxpool),VGG16的16代表权重系数,maxpool没有权重系数,故16=13+3.
出处:https://www.cnblogs.com/TheFaceOfAutumnWhenSummerEnd/p/13880686.html
最新更新
nodejs爬虫
Python正则表达式完全指南
爬取豆瓣Top250图书数据
shp 地图文件批量添加字段
爬虫小试牛刀(爬取学校通知公告)
【python基础】函数-初识函数
【python基础】函数-返回值
HTTP请求:requests模块基础使用必知必会
Python初学者友好丨详解参数传递类型
如何有效管理爬虫流量?
2个场景实例讲解GaussDB(DWS)基表统计信息估
常用的 SQL Server 关键字及其含义
动手分析SQL Server中的事务中使用的锁
openGauss内核分析:SQL by pass & 经典执行
一招教你如何高效批量导入与更新数据
天天写SQL,这些神奇的特性你知道吗?
openGauss内核分析:执行计划生成
[IM002]Navicat ODBC驱动器管理器 未发现数据
初入Sql Server 之 存储过程的简单使用
SQL Server -- 解决存储过程传入参数作为s
关于JS定时器的整理
JS中使用Promise.all控制所有的异步请求都完
js中字符串的方法
import-local执行流程与node模块路径解析流程
检测数据类型的四种方法
js中数组的方法,32种方法
前端操作方法
数据类型
window.localStorage.setItem 和 localStorage.setIte
如何完美解决前端数字计算精度丢失与数