2.Elegimos nuestro dataset y empezamos con el código
- Hola a todos. En este curso veremos como funcionan las redes de confrontación generativa condicional (cGAN).Estas redes a partir de una imagen de entrada generan una imagen de salida. El curso esta basado en este tutorial de google: https://www.tensorflow.org/beta/tutorials/generative/pix2pix
- Antes de comenzar con la implementación de nuestro código es necesario tener claro el problema que queremos solucionar en nuestro caso queremos generar fachadas a partir de bocetos. Hemos utilizado el dataset que google nos proporciona. En el siguiente vídeo lo explicamos mejor:
- Vamos a empezar a con el código.
#importamos las librerias necesarias.
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
import tensorflow as tf #IMPORTANTE TENER VERSIÓN 2.0
print(tf.__version__ )
#En esta porción de código descargamos el dataset de imagenes
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'
#Ruta local donde descomprimosmos el fichero zip con el dataset de imágenes
path_to_zip = tf.keras.utils.get_file('facades.tar.gz',origin=_URL,extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
print(PATH)
#definimos algunas constantes que utilizaremos posteriormente
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256 #Ancho de las imagenes
IMG_HEIGHT = 256 #Alto de las imagenes
#función que carga las imágenes y las devuelve
def load(image_file):
image = tf.io.read_file(image_file) #carga la imagen de disco
image = tf.image.decode_jpeg(image) # la decodificamos a jpg
w = tf.shape(image)[1]
w = w // 2
real_image = image[:, :w, :]
input_image = image[:, w:, :]
input_image = tf.cast(input_image, tf.float32) #cast para pasarlas a float para los calculos
real_image = tf.cast(real_image, tf.float32)
return input_image, real_image
#cargamos imagen de entrenamiento
inp, re = load(PATH+'train/2.jpg')
plt.figure()
plt.imshow(inp/255.0)
plt.figure()
plt.imshow(re/255.0)
#Función para redimensionar las imágenes.
#Se le pasan como paramteros las imágenes y sus tamaños.Llamamos
#al método de TS resize y las devuelve.
def resize(input_image, real_image, height, width):
input_image = tf.image.resize(input_image, [height, width],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
real_image = tf.image.resize(real_image, [height, width],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return input_image, real_image
#función que coge una parte de la imágenes pasadas como parámetros
def random_crop(input_image, real_image):
stacked_image = tf.stack([input_image, real_image], axis=0)
cropped_image = tf.image.random_crop(
stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
return cropped_image[0], cropped_image[1]
# Normalizamos.Queremos que las imágenes estén en el rango de
#[-1,1].Como las imágenes son de de tamaño 256 dividimos entre 127.5 y
#restamos 1
def normalize(input_image, real_image):
input_image = (input_image / 127.5) - 1
real_image = (real_image / 127.5) - 1
return input_image, real_image
#Funcíon que se utiliza para aumentación de datos.Aplicancdo está función se
#generan virtualmente más imágenes a base de ampliar las imagenes y desplazarlas.
@tf.function()
def random_jitter(input_image, real_image):
# aumentamos el tamaño 286 x 286 x 3 (canales de color)
input_image, real_image = resize(input_image, real_image, 286, 286)
# cogemos una parte de la imagen (256 x 256 x 3) anteriormente ampliada a 286 x 286 x 3
input_image, real_image = random_crop(input_image, real_image)
if tf.random.uniform(()) > 0.5:
# Volteamos la imagen horizontalmente
input_image = tf.image.flip_left_right(input_image)
real_image = tf.image.flip_left_right(real_image)
return input_image, real_image
# Probamos las funciones
# 1. Cambiar el tamaño de una imagen a mayor altura y ancho
# 2. Recorte aleatoriamente al tamaño original
# 3. Voltear aleatoriamente la imagen horizontalmente
plt.figure(figsize=(6, 6))
for i in range(4):
rj_inp, rj_re = random_jitter(inp, re)
plt.subplot(2, 2, i+1)
plt.imshow(rj_inp/255.0)
plt.axis('off')
plt.show()
#función para cargar imágenes de entrenamiento
def load_image_train(image_file):
input_image, real_image = load(image_file)
input_image, real_image = random_jitter(input_image, real_image)
input_image, real_image = normalize(input_image, real_image)
return input_image, real_image
#función para cargar imagen de test.En la función de test NO aplicamos
#el aumento de datos (función random_jitter)
def load_image_test(image_file):
input_image, real_image = load(image_file)
input_image, real_image = resize(input_image, real_image,
IMG_HEIGHT, IMG_WIDTH)
input_image, real_image = normalize(input_image, real_image)
return input_image, real_image
No hay comentarios:
Publicar un comentario