4.Construimos el discriminador
- Hola a todos. En este curso veremos como funcionan las redes de confrontación generativa condicional (cGAN).Estas redes a partir de una imagen de entrada generan una imagen de salida. El curso esta basado en este tutorial de google: https://www.tensorflow.org/beta/tutorials/generative/pix2pix
- En este capítulo vemos como crear nuestro discriminador de imágenes. En nuestro caso el discriminador observará el resultado generado y nos dirá si es cierto o falso , en relación a lo que el considera una imagen correcta. El discriminador en este caso es una PatchGAN: En vez de devolver cierto o falso devuelve una cuadricula en el que nos dice por si la parte de la imagen corespondiente a cada cuadricula es verdadedo o falso.
- Código visto en el vídeo:
#Cada bloque en el discriminador es (Conv -> BatchNorm -> Leaky ReLU)
#El discriminador recibira entradas.
#Imagen de entrada y la imagen de destino, que debe clasificar como real.
#Imagen de entrada y la imagen generada (salida del generador), que debe clasificar como #falsa.
#Concatenamos estas 2 entradas juntas en el código (tf.concat ([inp, tar], axis = -1))
def Discriminator():
initializer = tf.random_normal_initializer(0., 0.02)
inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')#imagen de entrada real
tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')#Imagen del generador
#concatenamos las dos entradas
x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)
#Construimos los distintos bloques
down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)
zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
#Añadimos la capa convolucional
conv = tf.keras.layers.Conv2D(512, 4, strides=1,
kernel_initializer=initializer,
use_bias=False)(zero_pad1) # (bs, 31, 31, 512)
#Añadimos la capa BatchNorm
batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
#Añadimos la capa Leaky ReLU
leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)
#Última capa con único filtro de salida con un canal donde nos dice por cada pixel
#de la imagen si parece real o no
last = tf.keras.layers.Conv2D(1, 4, strides=1,
kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)
#devolvemos el modelo del discriminador
return tf.keras.Model(inputs=[inp, tar], outputs=last)
#compruebo que todo funciona
discriminator = Discriminator()
#Pasamos dos imagenes y vemos un posible resultado.
#Recordar que todavia no hemos entrenado al modelo
disc_out = discriminator([inp[tf.newaxis,...], gen_output], training=False)
plt.imshow(disc_out[0,...,-1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()
#objeto para evaluar el resultado de las imágenes que vayamos obteniendo
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
#Función para evaluar el comportamiento del discriminador
def discriminator_loss(disc_real_output, disc_generated_output):
#Diferencia entre la observación de una imagen real y una matriz con todo a unos(representa
#que la imagen es real)
real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
#Diferencia entre la observación de una imagen generada y una matriz con todo a #ceros(representa que la imagen es fake)
generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
#La suma evalua el comportamiento del discriminador
total_disc_loss = real_loss + generated_loss
return total_disc_loss
LAMBDA = 100
#Función para evaluar el comportamiento del generador.
#Le pasamos como parametros el mapa generado por el dicriminador, la imagen generada,
#y la imagen real
def generator_loss(disc_generated_output, gen_output, target):
#Lo contrario al discminador
gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
#error absoluto medio
l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
#Para la suma además utilizamos el hiperparametro LAMBDA con valor 100
total_gen_loss = gan_loss + (LAMBDA * l1_loss)
return total_gen_loss
#Definimos los optimizadores
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
#guardamos checkpoints para permitir reanudar el entrenamiento donde nos quedamos
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
No hay comentarios:
Publicar un comentario