Генеративно-состязательная сеть, которую вы построите, создаёт персонажей из манги и аниме. Рисуйте вайфу в своё удовольствие!
Давно хотели создать своих Аску, Код 002 или Канеки Кена? У вас появилась отличная возможность это сделать :)
Что такое генеративно-состязательная сеть?
Лучший вывод, который может генерировать нейронная сеть, похож на человеческий. Образно генеративно-состязательная сеть (GAN) может даже обмануть человека, заставив его думать, что вывод сделан им самим.
В генеративно-состязательных сетях две сети соревнуются друг с другом, что приводит к взаимным импровизациям. Генератор обманывает дискриминатор, создавая ложные входы и выдавая их за реальные. Дискриминатор сообщает, является ли ввод реальным или ложным.
В обучении GAN есть три главных шага:
-
- Используйте генератор для создания ложных входов из случайного шума.
- Обучите дискриминатор на ложных и реальных входах (одновременно с объединением или поочерёдно, что предпочтительнее).
- Обучите всю модель: дискриминатор + генератор.
Помните, что весовые коэффициенты дискриминатора «заморожены» во время последнего шага.
Причина сочетания обоих сетей состоит в отсутствии обратной связи на выходах генератора. Единственный ориентир – если дискриминатор принимает выходы генератора.
Можно сказать, что они соперничают друг с другом. Генератор обучается во время схватки с «соперником», чтобы реализовать цель.
Наша сеть
Для задачи мы используем глубокую сверточную генеративно-состязательную сеть.
Вот несколько рекомендаций для таких сетей:
- Замените максимальные подвыборки шагами свёртки.
- Используйте перемещённую свёртку для повышения частоты дискретизации.
- Устраните полностью соединённые слои.
- Используйте пакетную нормализацию, кроме выходного слоя генератора и входного слоя дискриминатора.
- Используйте ReLU в генераторе, кроме выхода, который использует tanh.
- Leaky ReLU в дискриминаторе.
Детали сетапа
- версия Keras==2.2.4
- TensorFlow==1.8.0
- Jupyter Notebook
- Matplotlib и другие библиотеки типа NumPy, Pandas
- Python==3.5.7
Набор данных
Набор данных для лиц из аниме можно собрать на тематических сайтах, скачивая картинки и вырезая лица. Код Python, который демонстрирует это.
Также доступны обработанные и обрезанные лица.
Генератор
Он состоит из сверточных слоёв, пакетной нормализации и функции активации Leaky ReLU для повышения частоты дискретизации. Мы используем параметр шагов в сверточном слое, чтобы избежать нестабильной обучаемости GAN. Функция не будет равно нулю, если x < 0, вместо этого Leaky ReLU имеет небольшое отрицательное отклонение (0.01 и так далее).
Код
def get_gen_normal(noise_shape): kernel_init = 'glorot_uniform' gen_input = Input(shape = noise_shape) generator = Conv2DTranspose(filters = 512, kernel_size = (4,4), strides = (1,1), padding = "valid", data_format = "channels_last", kernel_initializer = kernel_init)(gen_input) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2DTranspose(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2DTranspose(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2DTranspose(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2D(filters = 64, kernel_size = (3,3), strides = (1,1), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = BatchNormalization(momentum = 0.5)(generator) generator = LeakyReLU(0.2)(generator) generator = Conv2DTranspose(filters = 3, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator) generator = Activation('tanh')(generator) gen_opt = Adam(lr=0.00015, beta_1=0.5) generator_model = Model(input = gen_input, output = generator) generator_model.compile(loss='binary_crossentropy', optimizer=gen_opt, metrics=['accuracy']) generator_model.summary() return generator_model
Дискриминатор
Он также состоит из сверточных слоёв, где мы используем шаги для понижения частоты дискретизации и пакетной нормализации для стабильности.
Код
def get_disc_normal(image_shape=(64,64,3)): dropout_prob = 0.4 kernel_init = 'glorot_uniform' dis_input = Input(shape = image_shape) discriminator = Conv2D(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(dis_input) discriminator = LeakyReLU(0.2)(discriminator) discriminator = Conv2D(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator) discriminator = BatchNormalization(momentum = 0.5)(discriminator) discriminator = LeakyReLU(0.2)(discriminator) discriminator = Conv2D(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator) discriminator = BatchNormalization(momentum = 0.5)(discriminator) discriminator = LeakyReLU(0.2)(discriminator) discriminator = Conv2D(filters = 512, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator) discriminator = BatchNormalization(momentum = 0.5)(discriminator) discriminator = LeakyReLU(0.2)(discriminator) discriminator = Flatten()(discriminator) discriminator = Dense(1)(discriminator) discriminator = Activation('sigmoid')(discriminator) dis_opt = Adam(lr=0.0002, beta_1=0.5) discriminator_model = Model(input = dis_input, output = discriminator) discriminator_model.compile(loss='binary_crossentropy', optimizer=dis_opt, metrics=['accuracy']) discriminator_model.summary() return discriminator_model
Полная GAN
Чтобы генератор проверял выход, скомпилируем сеть в Keras. В этой сети входом будет случайный шум для генератора, а выходом – выход генератора «скормленный» дискриминатору. Во избежание «коллапса состязания» весовые коэффициенты дискриминатора заморожены.
Код
discriminator.trainable = False opt = Adam(lr=0.00015, beta_1=0.5) #same as generator gen_inp = Input(shape=noise_shape) GAN_inp = generator(gen_inp) GAN_opt = discriminator(GAN_inp) gan = Model(input = gen_inp, output = GAN_opt) gan.compile(loss = 'binary_crossentropy', optimizer = opt, metrics=['accuracy']) gan.summary()
Тренировка модели
Базовая конфигурация модели
1. Сгенерируйте случайный нормальный шум для входа:
def gen_noise(batch_size, noise_shape): return np.random.normal(0, 1, size=(batch_size,)+noise_shape)
2. Объедините реальные данные из набора с шумом:
data_X = np.concatenate([real_data_X, fake_data_X]) real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2 fake_data_Y = np.random.random_sample(batch_size)*0.2 data_Y = np.concatenate([real_data_Y, fake_data_Y])
3. Подайте шум на вход:
real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2 fake_data_Y = np.random.random_sample(batch_size)*0.2
4. Тренируем только генератор
print("Begin step: ", tot_step) step_begin_time = time.time() real_data_X = sample_from_dataset(batch_size, image_shape, data_dir=data_dir) noise = gen_noise(batch_size,noise_shape) fake_data_X = generator.predict(noise) if (tot_step % 100) == 0: step_num = str(tot_step).zfill(4) save_img_batch(fake_data_X,img_save_dir+step_num+"_image.png")
5. Тренируем только дискриминатор:
discriminator.trainable = True generator.trainable = False dis_metrics_real = discriminator.train_on_batch(real_data_X,real_data_Y) dis_metrics_fake = discriminator.train_on_batch(fake_data_X,fake_data_Y) print("Disc: real loss: %f fake loss: %f" % (dis_metrics_real[0], dis_metrics_fake[0]))
6. Тренируем совмещённую GAN:
generator.trainable = True discriminator.trainable = False GAN_X = gen_noise(batch_size,noise_shape) GAN_Y = real_data_Y gan_metrics = gan.train_on_batch(GAN_X,GAN_Y) print("GAN loss: %f" % (gan_metrics[0])) text_file = open(log_dir+"\\training_log.txt", "a") text_file.write("Step: %d Disc: real loss: %f fake loss: %f GAN loss: %f\n" % (tot_step, dis_metrics_real[0], dis_metrics_fake[0],gan_metrics[0])) text_file.close() avg_GAN_loss.append(gan_metrics[0]) end_time = time.time() diff_time = int(end_time - step_begin_time) print("Step %d completed. Time took: %s secs." % (tot_step, diff_time))
7. Сохраните экземпляры дискриминатора и генератора:
if ((tot_step+1) % 500) == 0: print("-----------------------------------------------------------------") print("Average Disc_fake loss: %f" % (np.mean(avg_disc_fake_loss))) print("Average Disc_real loss: %f" % (np.mean(avg_disc_real_loss))) print("Average GAN loss: %f" % (np.mean(avg_GAN_loss))) print("-----------------------------------------------------------------") discriminator.trainable = False generator.trainable = False # predict on fixed_noise fixed_noise_generate = generator.predict(noise) step_num = str(tot_step).zfill(4) save_img_batch(fixed_noise_generate,img_save_dir+step_num+"fixed_image.png") generator.save(save_model_dir+str(tot_step)+"_GENERATOR_weights_and_arch.hdf5") discriminator.save(save_model_dir+str(tot_step)+"_DISCRIMINATOR_weights_and_arch.hdf5")
Результаты Манга-генератора
После 10000 шагов обучения результат выглядит круто! Смотрите сами.
Более длительная тренировка с большим набором данных приведёт к лучшим результатам. (Некоторые лица получились страшными, это правда :D)
Заключение
Наверняка задача генерации лиц в стиле аниме интересна.
Но в ней ещё есть место для улучшений: лучшее обучение, модели и набор данных. Наша модель вряд ли заставит человека задуматься, реальны ли сгенерированные модели. Тем не менее, она отлично справляется с поставленной задачей. Вы можете улучшить сеть, используя персонажей во весь рост.
Исходный код проекта доступен на GitHub.
Комментарии