換臉火了,我用 python 快速入門生成模型
引言:
近幾年來(lái),GAN生成對(duì)抗式應(yīng)用十分火熱,不論是抖音上大火的“螞蟻牙黑”還是B站上的“復(fù)原老舊照片”以及換臉等功能,都是基于GAN生成對(duì)抗式的模型。但是GAN算法對(duì)于大多數(shù)而言上手較難,故今天我們將使用最少的代碼,簡(jiǎn)單入門“生成對(duì)抗式網(wǎng)絡(luò)”,實(shí)現(xiàn)用GAN生成數(shù)字。
其中生成的圖片效果如下可見(jiàn):
01 模型建立
1.1 環(huán)境要求
本次環(huán)境使用的是python3.6.5+windows平臺(tái)
主要用的庫(kù)有:
OS模塊用來(lái)對(duì)本地文件讀寫(xiě)刪除、查找到等文件操作
numpy模塊用來(lái)矩陣和數(shù)據(jù)的運(yùn)算處理,其中也包括和深度學(xué)習(xí)框架之間的交互等
Keras模塊是一個(gè)由Python編寫(xiě)的開(kāi)源人工神經(jīng)網(wǎng)絡(luò)庫(kù),可以作為Tensorflow、Microsoft-CNTK和Theano的高階應(yīng)用程序接口,進(jìn)行深度學(xué)習(xí)模型的設(shè)計(jì)、調(diào)試、評(píng)估、應(yīng)用和可視化 。在這里我們用來(lái)搭建網(wǎng)絡(luò)層和直接讀取數(shù)據(jù)集操作,簡(jiǎn)單方便
Matplotlib模塊用來(lái)可視化訓(xùn)練效果等數(shù)據(jù)圖的制作
1.2 GAN簡(jiǎn)單介紹
GAN 由生成器 (Generator)和判別器 (Discriminator) 兩個(gè)網(wǎng)絡(luò)模型組成,這兩個(gè)模型作用并不相同,而是相互對(duì)抗。我們可以很簡(jiǎn)單的理解成,Generator是造假的的人,Discriminator是負(fù)責(zé)鑒寶的人。正是因?yàn)樯赡P秃蛯?duì)抗模型的相互對(duì)抗關(guān)系才稱之為生成對(duì)抗式。
那我們?yōu)槭裁床贿m用VAE去生成模型呢,又怎么知道GAN生成的圖片會(huì)比VAE生成的更優(yōu)呢?問(wèn)題就在于VAE模型作用是使得生成效果越相似越好,但事實(shí)上僅僅是相似卻只是依葫蘆畫(huà)瓢。而 GAN 是通過(guò) discriminator 來(lái)生成目標(biāo),而不是像 VAE線性般的學(xué)習(xí)。
這個(gè)項(xiàng)目里我們目標(biāo)是訓(xùn)練神經(jīng)網(wǎng)絡(luò)生成新的圖像,這些圖像與數(shù)據(jù)集中包含的圖像盡可能相近,而不是簡(jiǎn)單的復(fù)制粘貼。神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)什么是圖像的“本質(zhì)”,然后能夠從一個(gè)隨機(jī)的數(shù)字?jǐn)?shù)組開(kāi)始創(chuàng)建它。其主要思想是讓兩個(gè)獨(dú)立的神經(jīng)網(wǎng)絡(luò),一個(gè)產(chǎn)生器和一個(gè)鑒別器,相互競(jìng)爭(zhēng)。生成器會(huì)創(chuàng)建與數(shù)據(jù)集中的圖片盡可能相似的新圖像。判別器試圖了解它們是原始圖片還是合成圖片。
1.3 模型初始化
在這里我們初始化需要使用到的變量,以及優(yōu)化器、對(duì)抗式模型等。
def __init__(self, width=28, height=28, channels=1): self.width = width self.height = height self.channels = channels self.shape = (self.width, self.height, self.channels) self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8) self.G = self.__generator() self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer) self.D = self.__discriminator() self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy']) self.stacked_generator_discriminator = self.__stacked_generator_discriminator() self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer)
1.4 生成器模型的搭建
這里我們盡可能簡(jiǎn)單的搭建一個(gè)生成器模型,3個(gè)完全連接的層,使用sequential標(biāo)準(zhǔn)化。神經(jīng)元數(shù)分別是256,512,1024等:
def __generator(self): """ Declare generator """ model = Sequential() model.add(Dense(256, input_shape=(100,))) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(self.width * self.height * self.channels, activation='tanh')) model.add(Reshape((self.width, self.height, self.channels))) return model
1.5 判別器模型的搭建
在這里同樣簡(jiǎn)單搭建判別器網(wǎng)絡(luò)層,和生成器模型類似:
def __discriminator(self): """ Declare discriminator """ model = Sequential() model.add(Flatten(input_shape=self.shape)) model.add(Dense((self.width * self.height * self.channels), input_shape=self.shape)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(np.int64((self.width * self.height * self.channels)/2))) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) model.summary() return model
1.6 對(duì)抗式模型的搭建
這里是較為難理解的部分。讓我們創(chuàng)建一個(gè)對(duì)抗性模型,簡(jiǎn)單來(lái)說(shuō)這只是一個(gè)后面跟著一個(gè)鑒別器的生成器。注意,在這里鑒別器的權(quán)重被凍結(jié)了,所以當(dāng)我們訓(xùn)練這個(gè)模型時(shí),生成器層將不受影響,只是向上傳遞梯度。代碼很簡(jiǎn)單如下:
def __stacked_generator_discriminator(self): self.D.trainable = False model = Sequential() model.add(self.G) model.add(self.D) return model
02 模型的訓(xùn)練使用
2.1 模型的訓(xùn)練
在這里,我們并沒(méi)有直接去訓(xùn)練生成器。而是通過(guò)對(duì)抗性模型間接地訓(xùn)練它。我們將噪聲傳遞給了對(duì)抗模型,并將所有從數(shù)據(jù)庫(kù)中獲取的圖像標(biāo)記為負(fù)標(biāo)簽,而它們將由生成器生成。
對(duì)真實(shí)圖像進(jìn)行預(yù)先訓(xùn)練的鑒別器把不能合成的圖像標(biāo)記為真實(shí)圖像,所犯的錯(cuò)誤將導(dǎo)致由損失函數(shù)計(jì)算出的損失越來(lái)越高。這就是反向傳播發(fā)揮作用的地方。由于鑒別器的參數(shù)是凍結(jié)的,在這種情況下,反向傳播不會(huì)影響它們。相反,它會(huì)影響生成器的參數(shù)。所以優(yōu)化對(duì)抗性模型的損失函數(shù)意味著使生成的圖像盡可能的相似,鑒別器將識(shí)別為真實(shí)的。這既是生成對(duì)抗式的神奇之處!
故訓(xùn)練階段結(jié)束時(shí),我們的目標(biāo)是對(duì)抗性模型的損失值很小,而鑒別器的誤差盡可能高,這意味著它不再能夠分辨出差異。
最終在我門的訓(xùn)練結(jié)束時(shí),鑒別器損失約為0.73??紤]到我們給它輸入了50%的真實(shí)圖像和50%的合成圖像,這意味著它有時(shí)無(wú)法識(shí)別假圖像。這是一個(gè)很好的結(jié)果,考慮到這個(gè)例子絕對(duì)不是優(yōu)化的結(jié)果。要知道確切的百分比,我可以在編譯時(shí)添加一個(gè)精度指標(biāo),這樣它可能得到很多更好的結(jié)果實(shí)現(xiàn)更復(fù)雜的結(jié)構(gòu)的生成器和判別器。
代碼如下,這里legit_images是指原始訓(xùn)練的圖像,而syntetic_images是生成的圖像。:
def train(self, X_train, epochs=20000, batch = 32, save_interval = 100): for cnt in range(epochs): ## train discriminator random_index = np.random.randint(0, len(X_train) - np.int64(batch/2)) legit_images = X_train[random_index : random_index + np.int64(batch/2)].reshape(np.int64(batch/2), self.width, self.height, self.channels) gen_noise = np.random.normal(0, 1, (np.int64(batch/2), 100)) syntetic_images = self.G.predict(gen_noise) x_combined_batch = np.concatenate((legit_images, syntetic_images)) y_combined_batch = np.concatenate((np.ones((np.int64(batch/2), 1)), np.zeros((np.int64(batch/2), 1)))) d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch) # train generator noise = np.random.normal(0, 1, (batch, 100)) y_mislabled = np.ones((batch, 1)) g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabled) print ('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss)) if cnt % save_interval == 0: self.plot_images(save2file=True, step=cnt)
2.2 可視化
使用matplotlib來(lái)可視化模型訓(xùn)練效果。
def plot_images(self, save2file=False, samples=16, step=0): ''' Plot and generated images ''' if not os.path.exists("./images"): os.makedirs("./images") filename = "./images/mnist_%d.png" % step noise = np.random.normal(0, 1, (samples, 100)) images = self.G.predict(noise) plt.figure(figsize=(10, 10)) for i in range(images.shape[0]): plt.subplot(4, 4, i+1) image = images[i, :, :, :] image = np.reshape(image, [self.height, self.width]) plt.imshow(image, cmap='gray') plt.axis('off') plt.tight_layout() if save2file: plt.savefig(filename) plt.close('all') else: plt.show()
03 使用方法
考慮到代碼較少,下述代碼復(fù)制粘貼即可運(yùn)行。
# -*- coding: utf-8 -*- import os import numpy as np from IPython.core.debugger import Tracer from keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization from keras.layers.advanced_activations import LeakyReLU from keras.models import Sequential from keras.optimizers import Adam import matplotlib.pyplot as plt plt.switch_backend('agg') # allows code to run without a system DISPLAY class GAN(object): """ Generative Adversarial Network class """ def __init__(self, width=28, height=28, channels=1): self.width = width self.height = height self.channels = channels self.shape = (self.width, self.height, self.channels) self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8) self.G = self.__generator() self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer) self.D = self.__discriminator() self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy']) self.stacked_generator_discriminator = self.__stacked_generator_discriminator() self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer) def __generator(self): """ Declare generator """ model = Sequential() model.add(Dense(256, input_shape=(100,))) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(self.width * self.height * self.channels, activation='tanh')) model.add(Reshape((self.width, self.height, self.channels))) return model def __discriminator(self): """ Declare discriminator """ model = Sequential() model.add(Flatten(input_shape=self.shape)) model.add(Dense((self.width * self.height * self.channels), input_shape=self.shape)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(np.int64((self.width * self.height * self.channels)/2))) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) model.summary() return model def __stacked_generator_discriminator(self): self.D.trainable = False model = Sequential() model.add(self.G) model.add(self.D) return model def train(self, X_train, epochs=20000, batch = 32, save_interval = 100): for cnt in range(epochs): ## train discriminator random_index = np.random.randint(0, len(X_train) - np.int64(batch/2)) legit_images = X_train[random_index : random_index + np.int64(batch/2)].reshape(np.int64(batch/2), self.width, self.height, self.channels) gen_noise = np.random.normal(0, 1, (np.int64(batch/2), 100)) syntetic_images = self.G.predict(gen_noise) x_combined_batch = np.concatenate((legit_images, syntetic_images)) y_combined_batch = np.concatenate((np.ones((np.int64(batch/2), 1)), np.zeros((np.int64(batch/2), 1)))) d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch) # train generator noise = np.random.normal(0, 1, (batch, 100)) y_mislabled = np.ones((batch, 1)) g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabled) print ('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss)) if cnt % save_interval == 0: self.plot_images(save2file=True, step=cnt) def plot_images(self, save2file=False, samples=16, step=0): ''' Plot and generated images ''' if not os.path.exists("./images"): os.makedirs("./images") filename = "./images/mnist_%d.png" % step noise = np.random.normal(0, 1, (samples, 100)) images = self.G.predict(noise) plt.figure(figsize=(10, 10)) for i in range(images.shape[0]): plt.subplot(4, 4, i+1) image = images[i, :, :, :] image = np.reshape(image, [self.height, self.width]) plt.imshow(image, cmap='gray') plt.axis('off') plt.tight_layout() if save2file: plt.savefig(filename) plt.close('all') else: plt.show() if __name__ == '__main__': (X_train, _), (_, _) = mnist.load_data() # Rescale -1 to 1 X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=3) gan = GAN() gan.train(X_train)
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。