TLDR
Los métodos de IA generativa para la generación de imágenes tienen una amplia variedad de aplicaciones potenciales en marketing, ventas y comercio electrónico. Con estas aplicaciones en mente, el equipo de Salesforce Research ha desarrollado varias técnicas basadas en modelos de difusión de generación de imágenes, incluyendo métodos para la edición de imágenes, guía mejorada de clasificadores y métodos mejorados de generación controlada. En esta entrada de blog, documentamos nuestra experiencia de entrenamiento de grandes modelos de difusión de texto a imagen desde cero. Describimos nuestras decisiones de diseño, proceso de entrenamiento y métricas de rendimiento para la primera generación de nuestros modelos de generación de imágenes, llamada XGen-Image-1. En resumen:
- Entrenamos XGen-Image-1, un modelo de difusión latente de 860 millones de parámetros, utilizando 1.100 millones de imágenes disponibles públicamente del conjunto de datos LAION
- Combinando un modelo latente VAE con pixel upsamplers permite el entrenamiento a muy baja resolución, reduciendo el coste computacional.
- Un modelo de generación de imágenes competitivo puede entrenarse en la pila TPU de Google por ~75.000 dólares
- XGen-Image-1 iguala el rendimiento de alineación puntual de Stable Diffusion 1.5 y 2.1, que se encuentran entre los modelos de generación de imágenes con mejor rendimiento.
- Los retoques automatizados en regiones apuntadas (por ejemplo. «cara») son mejoras eficaces
- El muestreo de rechazo en el momento de la inferencia puede mejorar drásticamente los resultados
¿Qué vamos a entrenar?
En la actualidad, existen dos clases principales de modelos de difusión: basados en píxeles y basados en latencia. Los modelos de difusión basados en píxeles incluyen DeepFloyd-IF (Shonenkov et al. 2023), Imagen (Saharia et al. 2022), EDiffi (Balaji et al. 2022), Kandinsky (Shakhmatov et al. 2022) y DALLE-2 (Ramesh et al. 2022). Los modelos de difusión basados en latencia (LDM) incluyen la familia de modelos de difusión estable, de la que fueron pioneros Rombach et al. (2021) y Wuerstchen (Pernias et al. 2023).
La diferencia clave entre estas clases de modelos es que los modelos de difusión latente son representaciones de imágenes autocodificadas de denoising en un espacio comprimido (normalmente 8x espacialmente) mientras que los modelos de difusión basados en píxeles operan directamente sobre píxeles.
Las imágenes más grandes son caras de entrenar, por lo que la mayoría de estos enfoques normalmente acaban entrenando un modelo «base» a una resolución de 64×64. Para los modelos de píxeles, esto da como resultado imágenes bastante pequeñas, como se ve a continuación. El contenido es observable, pero no el detalle, y no se verá bien cuando se amplíe. Por ello, es necesario seguir la generación de la base con amplificadores (normalmente también modelos de difusión).
Para todos estos enfoques, se utilizan varios modelos preentrenados
- Condicionamiento del texto: CLIP/T5 language embeddings
- Modelos de difusión latente: VAE
También en el caso de los modelos de muestreo ascendente en cascada, cada componente se entrena de forma independiente, es decir, cada componente se preentrena con respecto a los demás (y viceversa)
Los autocodificadores (y los muestreadores ascendentes) entrenados de forma diversa son extremadamente adaptables a diferentes tipos de imágenes. Normalmente no se necesitan cambios en comparación con el cambio necesario en el proceso generativo. Esto no quiere decir que estos modelos sean perfectos; los artefactos de VAE StableDiffusion, como las características detalladas (por ejemplo, texto, caras pequeñas), son de sobra conocidos y los upsamplers pueden introducir sus propios artefactos. Estos modelos no son perfectos, pero son mucho más reutilizables que los modelos generativos de base.
Con esto en mente, decidimos que los problemas más directos del remuestreo, el condicionamiento y la codificación no tenían por qué ser el centro de nuestro modelo. En su lugar, nos preguntamos: puesto que estos elementos son robustos y reutilizables, ¿hasta qué punto podemos reutilizarlos?
Decidimos poner a prueba los límites de un entrenamiento eficiente y ver a qué baja resolución podíamos entrenar combinando modelos preentrenados de autocodificación y muestreo ascendente de píxeles.
Como se ilustra en la tubería anterior, utilizamos tanto un autoencoder preentrenado como upsamplers opcionales basados en píxeles. Esto nos permite generar imágenes con una resolución (latente) baja (32×32) y obtener imágenes de 1024×1024. En el futuro, queremos explorar aún más el límite inferior de la resolución práctica. A efectos de resultados, informamos de los resultados cuantitativos y la evaluación humana en la etapa de 256×256 directamente después de la VAE sin upsamplers. Los resultados cualitativos (al principio y al final de este post) utilizan tanto un «re-upsampler» de 256→64→256 (análogo al «Refiner» de SDXL) como un upsampler de 256→1024.
¿Con qué datos lo entrenaremos?
Siguiendo Stable Diffusion, entrenamos nuestro modelo utilizando el conjunto de datos LAION-2B con un filtro de puntuación estética de 4.5, que constituye ~42% del conjunto de datos. El conjunto de datos LAION se compone de imágenes extraídas de la web; echemos un vistazo rápido al aspecto de los pies de foto y las imágenes. Como se ve a continuación, hay muchas imágenes de productos con descripciones básicas, mucha ropa, etc. Sin embargo, lo bueno de los grandes conjuntos de datos es que en la larga cola de conceptos podemos encontrar casos de sustantivos muy raros. Si multiplicamos esta larga cola por la escala de los conjuntos de datos, ¡todavía hay muchos casos! Por ejemplo, en una muestra de 1 millón de imágenes, las palabras «dragón» y «astronauta» aparecen en 411 y 86 casos, respectivamente. Si tenemos en cuenta todo el conjunto de datos de imágenes 2B, eso significa que hay >800k y >170k instancias de entrenamiento etiquetadas de cada concepto (¡sin mencionar las variantes de esas palabras)! Para que nos hagamos una idea, esto significa que hay casi tantas instancias de «dragón» para aprender como imágenes en el desafío original ImageNet-1k total.
Infraestructura de entrenamiento
Entrenamos nuestro modelo en TPU v4s. Encontramos que el código de TPU Moco-V3 de Ronghang Hu era un punto de partida inestimable. Como parte del entrenamiento en TPUs, usamos Google Cloud Storage (GCS) para guardar modelos y usamos unidades montadas en gcloud para almacenar grandes conjuntos de datos. Entrenamos nuestro modelo en una máquina TPU v4-512 para 1,1 millones de pasos, lo que nos llevó unos 9 días con unos costes de hardware estimados de aproximadamente 73.000 dólares. El StableDiffusion original costó $600k.
Hipo en el entrenamiento
Al principio, nuestras pérdidas variaban mucho de un paso a otro, incluso con lotes de gran tamaño. Descubrimos que esto se debía a que todos los trabajadores recibían la misma siembra. Mediante la siembra aleatoria de los trabajadores con su rango, se logró una distribución uniforme de los pasos de ruido y curvas de pérdida más suaves.
Guardar los puntos de control del modelo resultó ser un problema sorprendentemente peliagudo en nuestra configuración de infraestructura. Los directorios locales ~ en TPUs no son persistentes, y nuestra unidad de código tenía I / O lento. Guardar en GCS en paralelo no funcionó de inmediato – con Pytorch/XLA hay que tener cuidado con lo que se está ejecutando por todos los trabajadores frente a sólo el maestro. En este caso algunas operaciones tomarán un bloqueo en la entrada GCS resultando en que las otras se cuelguen.
El bloque de código de abajo resuelve este problema guardando en GCS mientras sólo toca el cubo (sacando un bloqueo) en el hilo maestro.
if 'gs://' in archivo_o_ruta:
print("Haciendo blob")
gcs_path = file_or_path.replace('gs://', '')
nombre_cubo = gcs_path.split('/')[0]
cliente_almacenamiento = cliente_almacenamiento()
bucket = storage_client.bucket(nombre_bucket)
blob = bucket.blob('/'.join(gcs_path.split('/')[1:])
print("Abriendo blob")
print("iniciando bloque 'maestro'")
if debe_escribir_datos:
with blob.open('wb', ignore_flush=True) as f:
print("Guardando realmente")
torch.save(cpu_data, f)
Curvas de pérdidas: esperamos que un mayor entrenamiento siga mejorando nuestro modelo.
Evaluación automatizada de métricas
Realizamos una evaluación automática de nuestro modelo (y de los puntos de control) midiendo la puntuación CLIP (alineación con la indicación) en el eje x y la FID (similitud de apariencia a nivel de conjunto de datos) en el eje y. Estas métricas se calculan a través de 15 puntos de control. Estas métricas se calculan en 15 escalas de orientación para 30.000 pares imagen-pista en la primera figura (frente a las versiones StableDiffusion) y 1.000 pares para la comparación entre puntos de control. Los pares de datos se extraen aleatoriamente del conjunto de datos COCO Captions con «A photograph of» añadido al pie de foto para evitar penalizaciones FID asociadas con diferentes estilos gráficos (por ejemplo, ilustraciones).
Las evaluacionesCLIP-FID son limitadas, como se indica en la literatura (por ejemplo, SDXL, Podell et al. 2023), pero siguen siendo una métrica útil a gran escala. Vemos que nuestro modelo compite con SD1.5 y SD2.1, superando de hecho a los modelos StableDiffusion preentrenados en ambas métricas, lo que indica un alto fotorrealismo y una pronta fidelidad. Como prueba de cordura, observamos que las «épocas» secuenciales (12,5k pasos) suelen mejorar en ambas dimensiones.
Evaluación en humanos
Siguiendo a SDXL (Podell et al. 2023), realizamos una evaluación humana de nuestro modelo frente a SD1.5 y 2.1 en la prueba PartiPrompt (Yu et al. 2022), que mide la alineación de instrucciones, utilizando Amazon Mechanical Turk. Preguntamos a los usuarios «¿Cuál de las imágenes sigue mejor la indicación?» recogiendo respuestas para las 1632 indicaciones de la prueba comparativa en 6 pruebas separadas, resultando en ~10k respuestas en total por comparación. Las barras de error indican intervalos de confianza del 95%.
En la figura anterior mostramos la media general (de todas las secciones). Vemos que XGen-Image se clasifica de forma casi idéntica a SD1.5 mientras que está marginalmente (aunque no significativamente) por detrás de SD2.1.
No evaluamos directamente contra el reciente SDXL (Podell et al. 2023) que es un LDM mucho más grande que demuestra un rendimiento muy superior a SD1.x y 2.x. Actualmente estamos trabajando en escalar XGen-Image y abordar áreas específicas de mejora, con el objetivo de igualar el rendimiento de SDXL.
Generación consistente de imágenes de alta calidad
A partir de nuestro modelo entrenado XGen-Image, hemos implementado dos trucos comúnmente utilizados para la generación consistente de imágenes de alta calidad en nuestro proceso de inferencia.
- Generar un montón de imágenes y elegir la mejor
- Pintar las cosas que no se ven bien
Queríamos mantener la configuración de 1-prompt 1-output, así que buscamos automatizar lo anterior.
Para (1), probamos el muestreo de rechazo – generar múltiples imágenes y seleccionar automáticamente la mejor. Inicialmente exploramos la puntuación estética y la puntuación CLIP, pero encontramos que PickScore (Kirstain et al. 2023) proporcionaba una gran métrica general que, como se indica en su artículo, se correlacionaba bien con la preferencia humana.
Para hacer estos lotes eficientes utilizamos media precisión, atención eficiente y el programador PNDM (Liu et al. 2022). Esto nos permite generar 32 imágenes (4×8) en ~5 segundos en una GPU A100. Como se muestra en el siguiente ejemplo, la tasa de éxito de una generación alineada con un prompt no siempre va a ser del 0% o del 100% – permitiendo múltiples oportunidades para que el modelo acierte y siendo capaces de automáticamente determinar un buen candidato podemos mejorar la capacidad del pipeline global.
Como ejemplo de (2), aplicamos un proceso bastante estándar para la mejora regional de las caras (aunque se generaliza a cualquier objeto)
- Obtener máscara de segmentación para un objeto (desde prompt)
- Recortar en base a la máscara de segmentación
- Ampliar recorte
- Ejecutar img2img sobre recorte con un título que coincida/paralelamente a la pregunta de segmentación (para caras segmentamos con «una cara» e img2img con «una fotografía de una cara»)
- Utilizar la máscara de segmentación para mezclar el recorte aumentado con la imagen original
Evaluación cualitativa del muestreo de rechazo
Vemos que el muestreo de rechazo automático mediante PickScore mejora drásticamente el rendimiento de XGen-Image, induciendo una brecha mayor que cualquier diferencia de modelo con las versiones StableDiffusion. Aquí se evalúan todos los PartiPrompts en una sola prueba.
One More Collage
Conclusión
En este post hemos presentado XGen-Image-1, unun modelo de difusión latente texto-imagen entrenado para reutilizar varios componentes preentrenados. Nuestro prototipo sigue en gran medida el proceso LDM/SD, pero sólo se entrena con una resolución de 256×256 píxeles (32×32 latentes), lo que reduce el coste computacional. El prototipo XGen-Image se comporta de forma similar a Stable Diffusion 1.5 y 2.1 en la evaluación. Descubrimos que el uso de PickScore para realizar el muestreo de rechazo por lotes mejoraba drásticamente las generaciones, medido tanto por el rendimiento humano como por las métricas automáticas. Estamos entusiasmados por seguir desarrollando XGen-Image y compartir nuestras observaciones a lo largo del camino.
Desglose de contribuciones
Bram Wallace: Biblioteca de código, formación de modelos, canal de inferencia
Akash Gokul: Prototipos iniciales, ayuda con la codificación, mejoras en la velocidad del upsampler
Dongxu Li: Recogida de datos, formateo y carga del pipeline
Junnan Li: Asesoramiento en el entrenamiento del modelo y ayuda con el pipeline de datos
Nikhil Naik: Planificación, supervisión y gestión del proyecto
Damos las gracias a Srinath Reddy Meadusani y Lavanya Karanam por su ayuda y apoyo con la infraestructura informática. También agradecemos a Ran Xu, Ning Yu, Shu Zhang y Kathy Baxter sus sugerencias en diferentes etapas del proyecto. Por último, agradecemos a Caiming Xiong y Silvio Savarse por sus consejos y apoyo a lo largo del proyecto.