Аниме и генеративно-состязательная сеть: в чём связь?

Генеративно-состязательная сеть, которую вы построите, создаёт персонажей из манги и аниме. Рисуйте вайфу в своё удовольствие!

Аниме и генеративно-состязательная сеть: в чём связь?

Давно хотели создать своих Аску, Код 002 или Канеки Кена? У вас появилась отличная возможность это сделать :)

Что такое генеративно-состязательная сеть?

Лучший вывод, который может генерировать нейронная сеть, похож на человеческий. Образно генеративно-состязательная сеть (GAN) может даже обмануть человека, заставив его думать, что вывод сделан им самим.

В генеративно-состязательных сетях две сети соревнуются друг с другом, что приводит к взаимным импровизациям. Генератор обманывает дискриминатор, создавая ложные входы и выдавая их за реальные. Дискриминатор сообщает, является ли ввод реальным или ложным.

Аниме и генеративно-состязательная сеть: в чём связь?

В обучении GAN есть три главных шага:

    1. Используйте генератор для создания ложных входов из случайного шума.
    2. Обучите дискриминатор на ложных и реальных входах (одновременно с объединением или поочерёдно, что предпочтительнее).
    3. Обучите всю модель: дискриминатор + генератор.

Помните, что весовые коэффициенты дискриминатора «заморожены» во время последнего шага.

Причина сочетания обоих сетей состоит в отсутствии обратной связи на выходах генератора. Единственный ориентир – если дискриминатор принимает выходы генератора.

Генеративно-состязательная сеть

Можно сказать, что они соперничают друг с другом. Генератор обучается во время схватки с «соперником», чтобы реализовать цель.

Наша сеть

Для задачи мы используем глубокую сверточную генеративно-состязательную сеть.

Вот несколько рекомендаций для таких сетей:

  1. Замените максимальные подвыборки шагами свёртки.
  2. Используйте перемещённую свёртку для повышения частоты дискретизации.
  3. Устраните полностью соединённые слои.
  4. Используйте пакетную нормализацию, кроме выходного слоя генератора и входного слоя дискриминатора.
  5. Используйте ReLU в генераторе, кроме выхода, который использует tanh.
  6. Leaky ReLU в дискриминаторе.

Детали сетапа

  • версия Keras==2.2.4
  • TensorFlow==1.8.0
  • Jupyter Notebook
  • Matplotlib и другие библиотеки типа NumPy, Pandas
  • Python==3.5.7

Набор данных

Набор данных для лиц из аниме можно собрать на тематических сайтах, скачивая картинки и вырезая лица. Код Python, который демонстрирует это.

Также доступны обработанные и обрезанные лица.

Аниме и генеративно-состязательная сеть: в чём связь?

Генератор

Он состоит из сверточных слоёв, пакетной нормализации и функции активации Leaky ReLU для повышения частоты дискретизации. Мы используем параметр шагов в сверточном слое, чтобы избежать нестабильной обучаемости GAN. Функция не будет равно нулю, если x < 0, вместо этого Leaky ReLU имеет небольшое отрицательное отклонение (0.01 и так далее).

Аниме и генеративно-состязательная сеть: в чём связь?

Код

def get_gen_normal(noise_shape):
    kernel_init = 'glorot_uniform'    
    gen_input = Input(shape = noise_shape) 
    
    generator = Conv2DTranspose(filters = 512, kernel_size = (4,4), strides = (1,1), padding = "valid", data_format = "channels_last", kernel_initializer = kernel_init)(gen_input)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2DTranspose(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2DTranspose(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2DTranspose(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2D(filters = 64, kernel_size = (3,3), strides = (1,1), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    generator = BatchNormalization(momentum = 0.5)(generator)
    generator = LeakyReLU(0.2)(generator)
    
    generator = Conv2DTranspose(filters = 3, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(generator)
    
    generator = Activation('tanh')(generator)
    
    gen_opt = Adam(lr=0.00015, beta_1=0.5)
    generator_model = Model(input = gen_input, output = generator)
    generator_model.compile(loss='binary_crossentropy', optimizer=gen_opt, metrics=['accuracy'])
    generator_model.summary()

    return generator_model

Дискриминатор

Он также состоит из сверточных слоёв, где мы используем шаги для понижения частоты дискретизации и пакетной нормализации для стабильности.

Аниме и генеративно-состязательная сеть: в чём связь?

Код

def get_disc_normal(image_shape=(64,64,3)):
    dropout_prob = 0.4
    kernel_init = 'glorot_uniform'
    dis_input = Input(shape = image_shape)
    
    discriminator = Conv2D(filters = 64, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(dis_input)
    discriminator = LeakyReLU(0.2)(discriminator)

    discriminator = Conv2D(filters = 128, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)
    discriminator = BatchNormalization(momentum = 0.5)(discriminator)
    discriminator = LeakyReLU(0.2)(discriminator)

    discriminator = Conv2D(filters = 256, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)
    discriminator = BatchNormalization(momentum = 0.5)(discriminator)
    discriminator = LeakyReLU(0.2)(discriminator)

    discriminator = Conv2D(filters = 512, kernel_size = (4,4), strides = (2,2), padding = "same", data_format = "channels_last", kernel_initializer = kernel_init)(discriminator)
    discriminator = BatchNormalization(momentum = 0.5)(discriminator)
    discriminator = LeakyReLU(0.2)(discriminator)

    discriminator = Flatten()(discriminator)

    discriminator = Dense(1)(discriminator)

    discriminator = Activation('sigmoid')(discriminator)

    dis_opt = Adam(lr=0.0002, beta_1=0.5)
    discriminator_model = Model(input = dis_input, output = discriminator)
    discriminator_model.compile(loss='binary_crossentropy', optimizer=dis_opt, metrics=['accuracy'])
    discriminator_model.summary()

    return discriminator_model

Полная GAN

Чтобы генератор проверял выход, скомпилируем сеть в Keras. В этой сети входом будет случайный шум для генератора, а выходом – выход генератора «скормленный» дискриминатору. Во избежание «коллапса состязания» весовые коэффициенты дискриминатора заморожены.

Код

discriminator.trainable = False

opt = Adam(lr=0.00015, beta_1=0.5) #same as generator

gen_inp = Input(shape=noise_shape)
GAN_inp = generator(gen_inp)
GAN_opt = discriminator(GAN_inp)

gan = Model(input = gen_inp, output = GAN_opt)
gan.compile(loss = 'binary_crossentropy', optimizer = opt, metrics=['accuracy'])
gan.summary()

Аниме и генеративно-состязательная сеть: в чём связь?

Тренировка модели

Генеративно-состязательная сеть

Базовая конфигурация модели

Аниме и генеративно-состязательная сеть: в чём связь?

1. Сгенерируйте случайный нормальный шум для входа:

def gen_noise(batch_size, noise_shape):
    return np.random.normal(0, 1, size=(batch_size,)+noise_shape)

2. Объедините реальные данные из набора с шумом:

data_X = np.concatenate([real_data_X, fake_data_X])
real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
fake_data_Y = np.random.random_sample(batch_size)*0.2

data_Y = np.concatenate([real_data_Y, fake_data_Y])

3. Подайте шум на вход:

real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
fake_data_Y = np.random.random_sample(batch_size)*0.2

4. Тренируем только генератор

    print("Begin step: ", tot_step)
    step_begin_time = time.time() 
    
    real_data_X = sample_from_dataset(batch_size, image_shape, data_dir=data_dir)
    
    noise = gen_noise(batch_size,noise_shape)
    
    fake_data_X = generator.predict(noise)
    
    if (tot_step % 100) == 0:
        step_num = str(tot_step).zfill(4)
        save_img_batch(fake_data_X,img_save_dir+step_num+"_image.png")

5. Тренируем только дискриминатор:

    discriminator.trainable = True
    generator.trainable = False

    dis_metrics_real = discriminator.train_on_batch(real_data_X,real_data_Y) 
    dis_metrics_fake = discriminator.train_on_batch(fake_data_X,fake_data_Y) 
    
    print("Disc: real loss: %f fake loss: %f" % (dis_metrics_real[0], dis_metrics_fake[0]))

6. Тренируем совмещённую GAN:

    generator.trainable = True
    discriminator.trainable = False

    GAN_X = gen_noise(batch_size,noise_shape)
    GAN_Y = real_data_Y
   
    gan_metrics = gan.train_on_batch(GAN_X,GAN_Y)
    print("GAN loss: %f" % (gan_metrics[0]))
    
    text_file = open(log_dir+"\\training_log.txt", "a")
    text_file.write("Step: %d Disc: real loss: %f fake loss: %f GAN loss: %f\n" % (tot_step, dis_metrics_real[0], dis_metrics_fake[0],gan_metrics[0]))
    text_file.close()

    avg_GAN_loss.append(gan_metrics[0])
            
    end_time = time.time()
    diff_time = int(end_time - step_begin_time)
    print("Step %d completed. Time took: %s secs." % (tot_step, diff_time))

7. Сохраните экземпляры дискриминатора и генератора:

if ((tot_step+1) % 500) == 0:
        print("-----------------------------------------------------------------")
        print("Average Disc_fake loss: %f" % (np.mean(avg_disc_fake_loss))) 
        print("Average Disc_real loss: %f" % (np.mean(avg_disc_real_loss))) 
        print("Average GAN loss: %f" % (np.mean(avg_GAN_loss)))
        print("-----------------------------------------------------------------")
        discriminator.trainable = False
        generator.trainable = False
        # predict on fixed_noise
        fixed_noise_generate = generator.predict(noise)
        step_num = str(tot_step).zfill(4)
        save_img_batch(fixed_noise_generate,img_save_dir+step_num+"fixed_image.png")
        generator.save(save_model_dir+str(tot_step)+"_GENERATOR_weights_and_arch.hdf5")
        discriminator.save(save_model_dir+str(tot_step)+"_DISCRIMINATOR_weights_and_arch.hdf5")

Результаты Манга-генератора

После 10000 шагов обучения результат выглядит круто! Смотрите сами.

Генеративно-состязательная сеть Аниме и генеративно-состязательная сеть: в чём связь?

Более длительная тренировка с большим набором данных приведёт к лучшим результатам. (Некоторые лица получились страшными, это правда :D)

Заключение

Наверняка задача генерации лиц в стиле аниме интересна.

Но в ней ещё есть место для улучшений: лучшее обучение, модели и набор данных. Наша модель вряд ли заставит человека задуматься, реальны ли сгенерированные модели. Тем не менее, она отлично справляется с поставленной задачей. Вы можете улучшить сеть, используя персонажей во весь рост.

Исходный код проекта доступен на GitHub.

А что бы вы улучшили в этой нейронной сети?

МЕРОПРИЯТИЯ

Комментарии

ВАКАНСИИ

Добавить вакансию
Golang разработчик
Москва, от 150000 RUB до 200000 RUB
Middle IOS разработчик
Москва, от 220000 RUB до 320000 RUB
Менеджер продукта
Москва, по итогам собеседования

ЛУЧШИЕ СТАТЬИ ПО ТЕМЕ