Пожаловаться

Аниме и генеративно-состязательная сеть: в чём связь?

5020
Пожаловаться

Генеративно-состязательная сеть, которую вы построите, создаёт персонажей из манги и аниме. Рисуйте вайфу в своё удовольствие!

Аниме и генеративно-состязательная сеть: в чём связь?

Давно хотели создать своих Аску, Код 002 или Канеки Кена? У вас появилась отличная возможность это сделать :)

Что такое генеративно-состязательная сеть?

Лучший вывод, который может генерировать нейронная сеть, похож на человеческий. Образно генеративно-состязательная сеть (GAN) может даже обмануть человека, заставив его думать, что вывод сделан им самим.

В генеративно-состязательных сетях две сети соревнуются друг с другом, что приводит к взаимным импровизациям. Генератор обманывает дискриминатор, создавая ложные входы и выдавая их за реальные. Дискриминатор сообщает, является ли ввод реальным или ложным.

Аниме и генеративно-состязательная сеть: в чём связь?

В обучении GAN есть три главных шага:

    1. Используйте генератор для создания ложных входов из случайного шума.
    2. Обучите дискриминатор на ложных и реальных входах (одновременно с объединением или поочерёдно, что предпочтительнее).
    3. Обучите всю модель: дискриминатор + генератор.

Помните, что весовые коэффициенты дискриминатора «заморожены» во время последнего шага.

Причина сочетания обоих сетей состоит в отсутствии обратной связи на выходах генератора. Единственный ориентир – если дискриминатор принимает выходы генератора.

Генеративно-состязательная сеть

Можно сказать, что они соперничают друг с другом. Генератор обучается во время схватки с «соперником», чтобы реализовать цель.

Наша сеть

Для задачи мы используем глубокую сверточную генеративно-состязательную сеть.

Вот несколько рекомендаций для таких сетей:

  1. Замените максимальные подвыборки шагами свёртки.
  2. Используйте перемещённую свёртку для повышения частоты дискретизации.
  3. Устраните полностью соединённые слои.
  4. Используйте пакетную нормализацию, кроме выходного слоя генератора и входного слоя дискриминатора.
  5. Используйте ReLU в генераторе, кроме выхода, который использует tanh.
  6. Leaky ReLU в дискриминаторе.

Детали сетапа

  • версия Keras==2.2.4
  • TensorFlow==1.8.0
  • Jupyter Notebook
  • Matplotlib и другие библиотеки типа NumPy, Pandas
  • Python==3.5.7

Набор данных

Набор данных для лиц из аниме можно собрать на тематических сайтах, скачивая картинки и вырезая лица. Код Python, который демонстрирует это.

Также доступны обработанные и обрезанные лица.

Аниме и генеративно-состязательная сеть: в чём связь?

Генератор

Он состоит из сверточных слоёв, пакетной нормализации и функции активации Leaky ReLU для повышения частоты дискретизации. Мы используем параметр шагов в сверточном слое, чтобы избежать нестабильной обучаемости GAN. Функция не будет равно нулю, если x < 0, вместо этого Leaky ReLU имеет небольшое отрицательное отклонение (0.01 и так далее).

Аниме и генеративно-состязательная сеть: в чём связь?

Код

def get_gen_normal(noise_shape):
    kernel_init = 'glorot_uniform'    
    gen_input = Input(shape = noise_shape) 
    
    generator = Conv2DTranspose(filters = 512, kernel_size = (4,4), strides = (1,1), padding = "valid", data_format = "channels_last", kernel_initializer = kernel_init)(gen_input)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2DTranspose(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2DTranspose(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2DTranspose(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2D(filters = 64, kernel_size = (3,3), strides = (1,1), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2DTranspose(filters = 3, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    
    generator = Activation('tanh')(generator)
    
    gen_opt = Adam(lr=0.00015, beta_1=0.5)
    generator_model = Model(input = gen_input, output = generator)
    generator_model.compile(loss='binary_crossentropy', optimizer=gen_opt, metrics=['accuracy'])
    generator_model.summary()

    return generator_model

Дискриминатор

Он также состоит из сверточных слоёв, где мы используем шаги для понижения частоты дискретизации и пакетной нормализации для стабильности.

Аниме и генеративно-состязательная сеть: в чём связь?

Код

def get_disc_normal(image_shape=(64,64,3)):
    dropout_prob = 0.4
    kernel_init = 'glorot_uniform'
    dis_input = Input(shape = image_shape)
    
    discriminator = Conv2D(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(dis_input)
    discriminator = LeakyReLU(0.2)(discriminator)

    discriminator = Conv2D(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)
    discriminator = BatchNormalization(momentum = 0.5)(discriminator)
    discriminator = LeakyReLU(0.2)(discriminator)

    discriminator = Conv2D(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)
    discriminator = BatchNormalization(momentum = 0.5)(discriminator)
    discriminator = LeakyReLU(0.2)(discriminator)

    discriminator = Conv2D(filters = 512, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)
    discriminator = BatchNormalization(momentum = 0.5)(discriminator)
    discriminator = LeakyReLU(0.2)(discriminator)

    discriminator = Flatten()(discriminator)

    discriminator = Dense(1)(discriminator)

    discriminator = Activation('sigmoid')(discriminator)

    dis_opt = Adam(lr=0.0002, beta_1=0.5)
    discriminator_model = Model(input = dis_input, output = discriminator)
    discriminator_model.compile(loss='binary_crossentropy', optimizer=dis_opt, metrics=['accuracy'])
    discriminator_model.summary()

    return discriminator_model

Полная GAN

Чтобы генератор проверял выход, скомпилируем сеть в Keras. В этой сети входом будет случайный шум для генератора, а выходом – выход генератора «скормленный» дискриминатору. Во избежание «коллапса состязания» весовые коэффициенты дискриминатора заморожены.

Код

discriminator.trainable = False

opt = Adam(lr=0.00015, beta_1=0.5) #same as generator

gen_inp = Input(shape=noise_shape)
GAN_inp = generator(gen_inp)
GAN_opt = discriminator(GAN_inp)

gan = Model(input = gen_inp, output = GAN_opt)
gan.compile(loss = 'binary_crossentropy', optimizer = opt, metrics=['accuracy'])
gan.summary()

Аниме и генеративно-состязательная сеть: в чём связь?

Тренировка модели

Генеративно-состязательная сеть

Базовая конфигурация модели

Аниме и генеративно-состязательная сеть: в чём связь?

1. Сгенерируйте случайный нормальный шум для входа:

def gen_noise(batch_size, noise_shape):
    return np.random.normal(0, 1, size=(batch_size,)+noise_shape)

2. Объедините реальные данные из набора с шумом:

data_X = np.concatenate([real_data_X, fake_data_X])
real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
fake_data_Y = np.random.random_sample(batch_size)*0.2

data_Y = np.concatenate([real_data_Y, fake_data_Y])

3. Подайте шум на вход:

real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
fake_data_Y = np.random.random_sample(batch_size)*0.2

4. Тренируем только генератор

    print("Begin step: ", tot_step)
    step_begin_time = time.time() 
    
    real_data_X = sample_from_dataset(batch_size, image_shape, data_dir=data_dir)
    
    noise = gen_noise(batch_size,noise_shape)
    
    fake_data_X = generator.predict(noise)
    
    if (tot_step % 100) == 0:
        step_num = str(tot_step).zfill(4)
        save_img_batch(fake_data_X,img_save_dir+step_num+"_image.png")

5. Тренируем только дискриминатор:

    discriminator.trainable = True
    generator.trainable = False

    dis_metrics_real = discriminator.train_on_batch(real_data_X,real_data_Y) 
    dis_metrics_fake = discriminator.train_on_batch(fake_data_X,fake_data_Y) 
    
    print("Disc: real loss: %f fake loss: %f" % (dis_metrics_real[0], dis_metrics_fake[0]))

6. Тренируем совмещённую GAN:

    generator.trainable = True
    discriminator.trainable = False

    GAN_X = gen_noise(batch_size,noise_shape)
    GAN_Y = real_data_Y
   
    gan_metrics = gan.train_on_batch(GAN_X,GAN_Y)
    print("GAN loss: %f" % (gan_metrics[0]))
    
    text_file = open(log_dir+"\\training_log.txt", "a")
    text_file.write("Step: %d Disc: real loss: %f fake loss: %f GAN loss: %f\n" % (tot_step, dis_metrics_real[0], dis_metrics_fake[0],gan_metrics[0]))
    text_file.close()

    avg_GAN_loss.append(gan_metrics[0])
            
    end_time = time.time()
    diff_time = int(end_time - step_begin_time)
    print("Step %d completed. Time took: %s secs." % (tot_step, diff_time))

7. Сохраните экземпляры дискриминатора и генератора:

if ((tot_step+1) % 500) == 0:
        print("-----------------------------------------------------------------")
        print("Average Disc_fake loss: %f" % (np.mean(avg_disc_fake_loss))) 
        print("Average Disc_real loss: %f" % (np.mean(avg_disc_real_loss))) 
        print("Average GAN loss: %f" % (np.mean(avg_GAN_loss)))
        print("-----------------------------------------------------------------")
        discriminator.trainable = False
        generator.trainable = False
        # predict on fixed_noise
        fixed_noise_generate = generator.predict(noise)
        step_num = str(tot_step).zfill(4)
        save_img_batch(fixed_noise_generate,img_save_dir+step_num+"fixed_image.png")
        generator.save(save_model_dir+str(tot_step)+"_GENERATOR_weights_and_arch.hdf5")
        discriminator.save(save_model_dir+str(tot_step)+"_DISCRIMINATOR_weights_and_arch.hdf5")

Результаты Манга-генератора

После 10000 шагов обучения результат выглядит круто! Смотрите сами.

Генеративно-состязательная сеть Аниме и генеративно-состязательная сеть: в чём связь?

Более длительная тренировка с большим набором данных приведёт к лучшим результатам. (Некоторые лица получились страшными, это правда :D)

Заключение

Наверняка задача генерации лиц в стиле аниме интересна.

Но в ней ещё есть место для улучшений: лучшее обучение, модели и набор данных. Наша модель вряд ли заставит человека задуматься, реальны ли сгенерированные модели. Тем не менее, она отлично справляется с поставленной задачей. Вы можете улучшить сеть, используя персонажей во весь рост.

Исходный код проекта доступен на GitHub.

А что бы вы улучшили в этой нейронной сети?

5020

Комментарии

Рекомендуем

BUG!