5.Entrenamos el modelo
- Hola a todos. En este curso veremos como funcionan las redes de confrontación generativa condicional (cGAN).Estas redes a partir de una imagen de entrada generan una imagen de salida. El curso esta basado en este tutorial de google: https://www.tensorflow.org/beta/tutorials/generative/pix2pix
- En este capítulo entrenamos el modelo.Para entrenar el modelo comenzamos iterando sobre el conjunto de datos, después el generador obtiene la imagen de entrada y nosotros obtenemos una salida generada. El discriminador recibe la imagen de entrada y la imagen generada como la primera entrada. La segunda entrada es input_image y target_image. A continuación, calculamos el generador y la pérdida discriminadora. Luego, calculamos los gradientes de pérdida con respecto a las variables (entradas) del generador y del discriminador y las aplicamos al optimizador. En el siguiente vídeo lo explicamos:
- Vemos el el código del vídeo:
#Con la siguiente función evaluamos el comportamineto del modelo generador que estamos #entrenando le pasamos imágenes de nuestro conjunto de prueba
EPOCHS = 150
def generate_images(model, test_input, tar):
# the training=True is intentional here since
# we want the batch statistics while running the model
# on the test dataset. If we use training=False, we will get
# the accumulated statistics learned from the training dataset
# (which we don't want)
prediction = model(test_input, training=True)
plt.figure(figsize=(15,15))
display_list = [test_input[0], tar[0], prediction[0]]
title = ['Input Image', 'Ground Truth', 'Predicted Image']
for i in range(3):
plt.subplot(1, 3, i+1)
plt.title(title[i])
# getting the pixel values between [0, 1] to plot it.
plt.imshow(display_list[i] * 0.5 + 0.5)
plt.axis('off')
plt.show()
#Función que enlaza todos los modulos implementados
@tf.function
def train_step(input_image, target):
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
#El generador toma la imagen de entrada de la función.Esta imagen es comprimida por
#el encoder y descomprimida por el decoder y obtenemos imagen de salida
gen_output = generator(input_image, training=True)
#El discriminador observa lo que ha generado el generador y obtenemos
#la salida del discriminador
disc_real_output = discriminator([input_image, target], training=True)
disc_generated_output = discriminator([input_image, gen_output], training=True)
#Llamamos a las funciones generator_loss y discriminator_loss, que evaluan
#el comportamiento del generador y del discriminador
gen_loss = generator_loss(disc_generated_output, gen_output, target)
disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
#Variable para guardar los gradientes del generador
generator_gradients = gen_tape.gradient(gen_loss,
generator.trainable_variables)
#Variable para guardar los gradientes del discrimindaor
discriminator_gradients = disc_tape.gradient(disc_loss,
discriminator.trainable_variables)
#Optimizamos
generator_optimizer.apply_gradients(zip(generator_gradients,
generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
discriminator.trainable_variables))
#Definimos la rutina de entrenamiento en la siguiente función.
#A la función le pasamos los dataset de entrenamiento y de test y el número de épocas.
def fit(train_ds, epochs, test_ds):
for epoch in range(epochs):
start = time.time()
# Bucle para el entrenamiento
for input_image, target in train_ds:
train_step(input_image, target)
clear_output(wait=True)
for example_input, example_target in test_ds.take(1):
generate_images(generator, example_input, example_target)
# Cada 20 epocas salvamos el modelo
if (epoch + 1) % 20 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
time.time()-start))
#Entrenamos al modelo.En mi caso algo más de un día
fit(train_dataset, EPOCHS, test_dataset)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
# Run the trained model on the entire test dataset
for inp, tar in test_dataset.take(30):
generate_images(generator, inp, tar)