熱文 | 卷積神經(jīng)網(wǎng)絡(luò)入門案例,輕松實(shí)現(xiàn)花朵分類(1)
前言
本文介紹卷積神經(jīng)網(wǎng)絡(luò)的入門案例,通過搭建和訓(xùn)練一個(gè)模型,來對幾種常見的花朵進(jìn)行識(shí)別分類;使用到TF的花朵數(shù)據(jù)集,它包含5類,即:“雛菊”,“蒲公英”,“玫瑰”,“向日葵”,“郁金香”;共 3670 張彩色圖片;通過搭建和訓(xùn)練卷積神經(jīng)網(wǎng)絡(luò)模型,對圖像進(jìn)行分類,能識(shí)別出圖像是“蒲公英”,或“玫瑰”,還是其它。
本篇文章主要的意義是帶大家熟悉卷積神經(jīng)網(wǎng)絡(luò)的開發(fā)流程,包括數(shù)據(jù)集處理、搭建模型、訓(xùn)練模型、使用模型等;更重要的是解在訓(xùn)練模型時(shí)遇到“過擬合”,如何解決這個(gè)問題,從而得到“泛化”更好的模型。
思路流程
導(dǎo)入數(shù)據(jù)集
探索集數(shù)據(jù),并進(jìn)行數(shù)據(jù)預(yù)處理
構(gòu)建模型(搭建神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)、編譯模型)
訓(xùn)練模型(把數(shù)據(jù)輸入模型、評(píng)估準(zhǔn)確性、作出預(yù)測、驗(yàn)證預(yù)測)
使用訓(xùn)練好的模型
優(yōu)化模型、重新構(gòu)建模型、訓(xùn)練模型、使用模型
目錄
導(dǎo)入數(shù)據(jù)集
探索集數(shù)據(jù),并進(jìn)行數(shù)據(jù)預(yù)處理
構(gòu)建模型
訓(xùn)練模型
使用模型
優(yōu)化模型、重新構(gòu)建模型、訓(xùn)練模型、使用模型(過擬合、數(shù)據(jù)增強(qiáng)、正則化、重新編譯和訓(xùn)練模型、預(yù)測新數(shù)據(jù))
導(dǎo)入數(shù)據(jù)集
使用到TF的花朵數(shù)據(jù)集,它包含5類,即:“雛菊”,“蒲公英”,“玫瑰”,“向日葵”,“郁金香”;共 3670 張彩色圖片;數(shù)據(jù)集包含5個(gè)子目錄,每個(gè)子目錄種存放一個(gè)類別的花朵圖片。
# 下載數(shù)據(jù)集 import pathlib dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True) data_dir = pathlib.Path(data_dir) # 查看數(shù)據(jù)集圖片的總數(shù)量 image_count = len(list(data_dir.glob('*/*.jpg'))) print(image_count)
探索集數(shù)據(jù),并進(jìn)行數(shù)據(jù)預(yù)處理
查看一張郁金香的圖片:
# 查看郁金香tulips目錄下的第1張圖片; tulips = list(data_dir.glob('tulips/*')) PIL.Image.open(str(tulips[0]))
加載數(shù)據(jù)集的圖片,使用keras.preprocessing從磁盤上加載這些圖像。
# 定義加載圖片的一些參數(shù),包括:批量大小、圖像高度、圖像寬度 batch_size = 32 img_height = 180 img_width = 180 # 將80%的圖像用于訓(xùn)練 train_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, validation_split=0.2, subset="training", seed=123, image_size=(img_height, img_width), batch_size=batch_size) # 將20%的圖像用于驗(yàn)證 val_ds = tf.keras.preprocessing.image_dataset_from_directory( data_dir, validation_split=0.2, subset="validation", seed=123, image_size=(img_height, img_width), batch_size=batch_size) # 打印數(shù)據(jù)集中花朵的類別名稱,字母順序?qū)?yīng)于目錄名稱 class_names = train_ds.class_names print(class_names)
查看一下訓(xùn)練數(shù)據(jù)集中的9張圖像
# 查看一下訓(xùn)練數(shù)據(jù)集中的9張圖像 import matplotlib.pyplot as plt plt.figure(figsize=(10, 10)) for images, labels in train_ds.take(1): for i in range(9): ax = plt.subplot(3, 3, i + 1) plt.imshow(images[i].numpy().astype("uint8")) plt.title(class_names[labels[i]]) plt.axis("off")
圖像形狀
傳遞這些數(shù)據(jù)集來訓(xùn)練模型model.fit,可以手動(dòng)遍歷數(shù)據(jù)集并檢索成批圖像:
for image_batch, labels_batch in train_ds: print(image_batch.shape) print(labels_batch.shape) break
能看到輸出:(32, 180, 180, 3) (32,)
image_batch是圖片形狀的張量(32, 180, 180, 3)。32是指批量大??;180,180分別表示圖像的高度、寬度,3是顏色通道RGB。32張圖片組成一個(gè)批次。
label_batch是形狀的張量(32,),對應(yīng)32張圖片的標(biāo)簽。
數(shù)據(jù)集預(yù)處理
下面進(jìn)行數(shù)據(jù)集預(yù)處理,將像素的值標(biāo)準(zhǔn)化至0到1的區(qū)間內(nèi):
# 將像素的值標(biāo)準(zhǔn)化至0到1的區(qū)間內(nèi)。 normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)
為什么是除以255呢?由于圖片的像素范圍是0~255,我們把它變成0~1的范圍,于是每張圖像(訓(xùn)練集、測試集)都除以255。
標(biāo)準(zhǔn)化數(shù)據(jù)
# 調(diào)用map將其應(yīng)用于數(shù)據(jù)集: normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) image_batch, labels_batch = next(iter(normalized_ds)) first_image = image_batch[0] # Notice the pixels values are now in `[0,1]`. print(np.min(first_image), np.max(first_image))
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請聯(lián)系工作人員刪除。