23. Entrenamiento de un clasificador de imágenes

23.1 Introducción

En el tema anterior construimos una CNN en PyTorch. Definimos sus capas, entendimos el papel de nn.Module y verificamos que el modelo podía recibir imágenes y producir una salida con la forma esperada. Pero todavía faltaba lo más importante: hacer que aprenda.

Entrenar un clasificador de imágenes significa ajustar los pesos del modelo para que, a partir de ejemplos etiquetados, aprenda a asociar patrones visuales con clases correctas. Esa etapa requiere mucho más que una arquitectura: necesitamos datos, una función de pérdida, un optimizador y un ciclo de entrenamiento bien planteado.

En este tema veremos la lógica completa del entrenamiento en PyTorch. La idea es comprender no solo qué líneas de código hay que escribir, sino también qué está ocurriendo conceptualmente en cada paso.

23.2 ¿Qué significa entrenar?

Entrenar un modelo significa exponerlo repetidamente a ejemplos de entrada y comparar sus predicciones con las etiquetas correctas. A partir de ese error, el algoritmo ajusta los pesos para que futuras predicciones sean mejores.

En clasificación de imágenes, el proceso general es:

  • Tomar un batch de imágenes.
  • Pasarlo por la red.
  • Obtener una predicción.
  • Compararla con la clase correcta.
  • Calcular una pérdida.
  • Actualizar los parámetros.

Este ciclo se repite muchas veces hasta que el modelo mejora su desempeño.

Entrenar no es “ejecutar la red”. Es corregirla una y otra vez a partir del error que comete sobre los datos.

23.3 Ingredientes básicos del entrenamiento

Para entrenar un clasificador de imágenes en PyTorch necesitamos cuatro elementos centrales:

  • Un modelo.
  • Un conjunto de datos y un DataLoader.
  • Una función de pérdida.
  • Un optimizador.

Además, necesitamos organizar el loop de entrenamiento y, si queremos hacerlo bien, también un conjunto de validación.

23.4 El papel del dataset

El modelo no aprende en el vacío: aprende a partir de un dataset etiquetado. Cada ejemplo contiene una imagen y una clase asociada.

En PyTorch, un dataset suele devolver pares del estilo:

(imagen, etiqueta)

La imagen se representa como tensor y la etiqueta suele ser un entero que identifica la clase. Por ejemplo, en un problema de 10 clases, las etiquetas pueden ir de 0 a 9.

23.5 ¿Por qué usar DataLoader?

Aunque podríamos recorrer el dataset manualmente, PyTorch ofrece DataLoader para automatizar varias tareas importantes:

  • Agrupar ejemplos en batches.
  • Mezclar el orden de entrenamiento.
  • Iterar cómodamente sobre los datos.
  • Preparar mejor el pipeline de carga.

Esto simplifica mucho el código y hace el entrenamiento más ordenado.

23.6 Batches y actualización por lotes

En la práctica, el modelo no suele actualizarse imagen por imagen, sino por lotes o batches. Un batch puede tener 8, 16, 32 o más ejemplos, según el problema y la memoria disponible.

Trabajar por batches tiene varias ventajas:

  • Aprovecha mejor la computación paralela.
  • Produce gradientes más estables que un ejemplo aislado.
  • Hace el entrenamiento más eficiente.

23.7 Función de pérdida

La función de pérdida mide qué tan equivocada estuvo la red. Es el puente entre la predicción del modelo y la corrección matemática necesaria para ajustar sus parámetros.

En clasificación multiclase, una de las opciones más comunes es nn.CrossEntropyLoss().

criterion = nn.CrossEntropyLoss()

Esta función espera logits como salida del modelo y etiquetas enteras como objetivo.

23.8 ¿Por qué CrossEntropyLoss?

Porque es una función pensada justamente para clasificación multiclase. Castiga al modelo cuando asigna baja puntuación a la clase correcta y alta puntuación a clases incorrectas.

Además, PyTorch la implementa de forma estable y eficiente. Un detalle importante es que, al usar CrossEntropyLoss, la red no debe aplicar softmax manualmente en la salida final.

23.9 El optimizador

Una vez calculada la pérdida, necesitamos un mecanismo para actualizar los pesos. Ese es el papel del optimizador.

Un optimizador muy usado es Adam:

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

Aquí le decimos al optimizador qué parámetros debe modificar y con qué tasa de aprendizaje hacerlo.

23.10 Forward, pérdida, backward y step

El corazón del entrenamiento está en esta secuencia:

  1. Forward: el modelo produce una salida.
  2. Se calcula la pérdida.
  3. Backward: se calcula el gradiente.
  4. Step: el optimizador actualiza los pesos.

En PyTorch, esta lógica suele verse así:

outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

23.11 El paso que muchos olvidan: zero_grad()

Antes de calcular nuevos gradientes, normalmente se limpian los gradientes acumulados del paso anterior:

optimizer.zero_grad()

Esto es importante porque PyTorch acumula gradientes por defecto. Si no los limpiamos, las actualizaciones mezclarían información de varios pasos y producirían un comportamiento incorrecto.

23.12 El loop de entrenamiento

El entrenamiento completo se organiza normalmente en épocas. Cada época recorre todos los batches del conjunto de entrenamiento.

La estructura general es:

  • Recorrer épocas.
  • Dentro de cada época, recorrer batches.
  • Hacer forward, pérdida, backward y actualización.
  • Acumular métricas para monitorear progreso.

Este es el patrón básico que aparece en casi cualquier proyecto serio con PyTorch.

23.13 Modo entrenamiento y modo evaluación

PyTorch distingue entre modo entrenamiento y modo evaluación. Esto es importante porque algunas capas, como Dropout o BatchNorm, se comportan distinto según el contexto.

Para entrenamiento se usa:

model.train()

Y para validación o inferencia:

model.eval()

Aunque nuestra CNN mínima no use todas esas capas, conviene adoptar esta práctica desde el principio.

23.14 Validación

Entrenar bien no significa solo bajar la pérdida de entrenamiento. También necesitamos saber cómo se comporta el modelo sobre datos no usados directamente para ajustar los pesos. Para eso sirve el conjunto de validación.

Durante validación no se actualizan parámetros. Solo se mide el desempeño del modelo para ver si está generalizando o si empieza a sobreajustarse.

23.15 torch.no_grad()

Cuando validamos o inferimos, no necesitamos gradientes. Por eso PyTorch permite envolver ese bloque con:

with torch.no_grad():
    outputs = model(images)

Esto ahorra memoria y acelera el proceso, además de dejar claro que no queremos actualizar nada en esa etapa.

23.16 Accuracy como métrica simple

Además de la pérdida, en clasificación suele medirse la accuracy, es decir, el porcentaje de ejemplos clasificados correctamente.

Una forma habitual de obtener la clase predicha es usando argmax sobre la dimensión de clases:

preds = outputs.argmax(dim=1)

Luego se compara preds con las etiquetas reales y se cuentan los aciertos.

23.17 ¿Cómo saber si la red está aprendiendo?

Algunas señales típicas de que el entrenamiento progresa son:

  • La pérdida de entrenamiento tiende a bajar.
  • La accuracy de entrenamiento tiende a subir.
  • La validación también mejora, al menos durante varias épocas.

Si el entrenamiento mejora pero la validación empeora, puede estar apareciendo overfitting.

23.18 Guardar el mejor modelo

En muchos proyectos conviene guardar el estado del modelo cuando logra su mejor rendimiento de validación. PyTorch lo hace sencillo con torch.save.

torch.save(model.state_dict(), "mejor_modelo.pth")

Luego ese archivo puede cargarse más adelante para continuar entrenamiento o hacer inferencia.

23.19 Un ejemplo didáctico y ejecutable

Para que el código de este tema sea fácil de ejecutar sin depender de descargas externas, usaremos torchvision.datasets.FakeData. Este dataset genera imágenes y etiquetas sintéticas, lo que permite practicar toda la lógica del entrenamiento sin preocuparse por conseguir archivos reales.

El objetivo aquí no es obtener una accuracy impresionante, sino entender el pipeline completo.

23.20 Definición del dataset y DataLoader

Con FakeData podemos crear entrenamiento y validación así:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.ToTensor()

train_dataset = datasets.FakeData(
    size=800,
    image_size=(3, 64, 64),
    num_classes=4,
    transform=transform
)

val_dataset = datasets.FakeData(
    size=200,
    image_size=(3, 64, 64),
    num_classes=4,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Esto ya deja listo el flujo de datos para entrenamiento y validación.

23.21 Modelo para el ejemplo

Podemos reutilizar una CNN pequeña, similar a la del tema anterior, adaptándola a 4 clases:

class CNNPequena(nn.Module):
    def __init__(self, num_clases=4):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, num_clases)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

23.22 Estructura del entrenamiento por época

Dentro de cada época solemos acumular pérdida y aciertos para luego calcular métricas promedio. Eso permite seguir la evolución del modelo de forma clara.

Conceptualmente, cada época hace dos cosas:

  • Entrena sobre train_loader.
  • Evalúa sobre val_loader.

23.23 Código completo de entrenamiento

Este ejemplo reúne modelo, dataset, optimizador, pérdida, entrenamiento y validación:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class CNNPequena(nn.Module):
    def __init__(self, num_clases=4):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, num_clases)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.ToTensor()

train_dataset = datasets.FakeData(
    size=800,
    image_size=(3, 64, 64),
    num_classes=4,
    transform=transform
)

val_dataset = datasets.FakeData(
    size=200,
    image_size=(3, 64, 64),
    num_classes=4,
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

model = CNNPequena(num_clases=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        train_correct += (preds == labels).sum().item()
        train_total += labels.size(0)

    avg_train_loss = train_loss / train_total
    train_acc = train_correct / train_total

    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    avg_val_loss = val_loss / val_total
    val_acc = val_correct / val_total

    print(
        f"Epoca {epoch+1}/{num_epochs} | "
        f"train_loss={avg_train_loss:.4f} | train_acc={train_acc:.4f} | "
        f"val_loss={avg_val_loss:.4f} | val_acc={val_acc:.4f}"
    )

Este script ya permite ver el entrenamiento completo en funcionamiento y observar cómo evolucionan las métricas a lo largo de las épocas.

23.24 Qué hace cada parte del ejemplo

Vale la pena leer ese código por bloques:

  • Primero se define el modelo.
  • Después se crean dataset y dataloaders.
  • Luego se eligen pérdida y optimizador.
  • Finalmente se ejecuta el loop de épocas con entrenamiento y validación.

Esta estructura es extremadamente común en proyectos con PyTorch, incluso cuando el problema real es mucho más grande.

23.25 Qué limitaciones tiene este ejemplo

Como usa datos sintéticos, no debemos interpretar sus métricas como si fueran el resultado de un problema real. Su valor es didáctico.

Cuando trabajemos con datasets reales, habrá que sumar aspectos como:

  • Transformaciones más específicas.
  • Augmentation realista.
  • Separación cuidadosa de train, validation y test.
  • Control de overfitting.
  • Evaluación más completa.

23.26 Un problema real: clasificar imágenes de CIFAR-10

Ahora sí podemos plantear un ejercicio más cercano a un caso real. En lugar de generar imágenes artificiales, vamos a usar CIFAR-10, uno de los datasets clásicos de visión por computadora.

El enunciado del problema puede formularse así:

Dada una imagen color de 32x32 píxeles, el modelo debe predecir a cuál de estas 10 clases pertenece: avión, auto, pájaro, gato, ciervo, perro, rana, caballo, barco o camión.

CIFAR-10 contiene 60.000 imágenes en total, divididas en 50.000 para entrenamiento y 10.000 para prueba. Cada imagen ya viene etiquetada con una de esas diez categorías.

Este problema ya no es sintético: las imágenes son reales, pequeñas y variadas, por lo que el modelo debe aprender patrones visuales auténticos.

23.27 Código completo con CIFAR-10 y una interfaz en Tkinter

El siguiente programa descarga el dataset, entrena una CNN pequeña y muestra en vivo imágenes del batch actual junto con las predicciones del modelo:

Interfaz de entrenamiento de CIFAR-10 mostrando imágenes del batch actual
import threading
import queue
import tkinter as tk
from tkinter import ttk, messagebox

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from PIL import Image, ImageTk
import numpy as np


# =========================================================
# MODELO
# =========================================================
class CNNPequena(nn.Module):
    def __init__(self, num_clases=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_clases)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # 32x32 -> 16x16
        x = self.pool(F.relu(self.conv2(x)))   # 16x16 -> 8x8
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# =========================================================
# APP
# =========================================================
class AppEntrenamientoCIFAR10:
    def __init__(self, root):
        self.root = root
        self.root.title("Entrenamiento de CIFAR-10 en vivo")
        self.root.geometry("1450x920")
        self.root.configure(bg="#f4f6f8")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.clases = [
            "avion", "auto", "pajaro", "gato", "ciervo",
            "perro", "rana", "caballo", "barco", "camion"
        ]

        self.model = None
        self.train_loader = None
        self.val_loader = None
        self.optimizer = None
        self.criterion = None

        self.batch_size = 16
        self.num_epochs = 5
        self.learning_rate = 0.001

        self.imagenes_tk = []
        self.labels_imagenes = []

        self.queue_ui = queue.Queue()
        self.stop_requested = False
        self.training_thread = None

        self.crear_interfaz()
        self.root.after(200, self.procesar_cola)

    # -----------------------------------------------------
    # INTERFAZ
    # -----------------------------------------------------
    def crear_interfaz(self):
        marco_principal = tk.Frame(self.root, bg="#f4f6f8")
        marco_principal.pack(fill="both", expand=True, padx=10, pady=10)

        # Panel superior
        panel_superior = tk.Frame(marco_principal, bg="#f4f6f8")
        panel_superior.pack(fill="x", pady=(0, 10))

        titulo = tk.Label(
            panel_superior,
            text="Clasificador de imagenes con CIFAR-10",
            font=("Arial", 22, "bold"),
            bg="#f4f6f8",
            fg="#1f2d3d"
        )
        titulo.pack(anchor="w")

        subtitulo = tk.Label(
            panel_superior,
            text="La aplicacion descarga el dataset, entrena una CNN y muestra en vivo imagenes del batch actual.",
            font=("Arial", 11),
            bg="#f4f6f8",
            fg="#4a6572"
        )
        subtitulo.pack(anchor="w", pady=(4, 0))

        # Panel contenido
        panel_contenido = tk.Frame(marco_principal, bg="#f4f6f8")
        panel_contenido.pack(fill="both", expand=True)

        # Panel izquierdo controles
        panel_izq = tk.Frame(panel_contenido, bg="white", bd=1, relief="solid")
        panel_izq.pack(side="left", fill="y", padx=(0, 10))

        tk.Label(
            panel_izq,
            text="Parametros",
            font=("Arial", 16, "bold"),
            bg="white",
            fg="#1f2d3d"
        ).pack(anchor="w", padx=15, pady=(15, 10))

        frm_params = tk.Frame(panel_izq, bg="white")
        frm_params.pack(fill="x", padx=15, pady=5)

        tk.Label(frm_params, text="Epocas:", font=("Arial", 11), bg="white").grid(row=0, column=0, sticky="w", pady=6)
        self.entry_epochs = ttk.Entry(frm_params, width=10)
        self.entry_epochs.grid(row=0, column=1, sticky="w", pady=6)
        self.entry_epochs.insert(0, "5")

        tk.Label(frm_params, text="Batch size:", font=("Arial", 11), bg="white").grid(row=1, column=0, sticky="w", pady=6)
        self.entry_batch = ttk.Entry(frm_params, width=10)
        self.entry_batch.grid(row=1, column=1, sticky="w", pady=6)
        self.entry_batch.insert(0, "16")

        tk.Label(frm_params, text="Learning rate:", font=("Arial", 11), bg="white").grid(row=2, column=0, sticky="w", pady=6)
        self.entry_lr = ttk.Entry(frm_params, width=10)
        self.entry_lr.grid(row=2, column=1, sticky="w", pady=6)
        self.entry_lr.insert(0, "0.001")

        tk.Label(
            panel_izq,
            text="Dispositivo:",
            font=("Arial", 11, "bold"),
            bg="white"
        ).pack(anchor="w", padx=15, pady=(15, 2))

        self.lbl_device = tk.Label(
            panel_izq,
            text=str(self.device),
            font=("Arial", 11),
            bg="white",
            fg="#1565c0"
        )
        self.lbl_device.pack(anchor="w", padx=15)

        frm_botones = tk.Frame(panel_izq, bg="white")
        frm_botones.pack(fill="x", padx=15, pady=20)

        self.btn_iniciar = ttk.Button(frm_botones, text="Descargar y entrenar", command=self.iniciar_entrenamiento)
        self.btn_iniciar.pack(fill="x", pady=4)

        self.btn_detener = ttk.Button(frm_botones, text="Detener", command=self.detener_entrenamiento)
        self.btn_detener.pack(fill="x", pady=4)

        # Estado
        tk.Label(
            panel_izq,
            text="Estado",
            font=("Arial", 16, "bold"),
            bg="white",
            fg="#1f2d3d"
        ).pack(anchor="w", padx=15, pady=(10, 10))

        self.lbl_estado = tk.Label(
            panel_izq,
            text="Listo para comenzar",
            font=("Arial", 11),
            bg="white",
            fg="#2e7d32",
            justify="left",
            wraplength=260
        )
        self.lbl_estado.pack(anchor="w", padx=15)

        self.barra = ttk.Progressbar(panel_izq, mode="indeterminate")
        self.barra.pack(fill="x", padx=15, pady=15)

        self.lbl_metricas = tk.Label(
            panel_izq,
            text="Epoca: -\nBatch: -\nLoss: -\nAccuracy batch: -\nTrain acc: -\nVal acc: -",
            font=("Consolas", 11),
            bg="white",
            fg="#263238",
            justify="left"
        )
        self.lbl_metricas.pack(anchor="w", padx=15, pady=(0, 20))

        # Panel derecho
        panel_der = tk.Frame(panel_contenido, bg="#f4f6f8")
        panel_der.pack(side="left", fill="both", expand=True)

        # Cuadricula de imagenes
        panel_grid = tk.Frame(panel_der, bg="white", bd=1, relief="solid")
        panel_grid.pack(fill="both", expand=True)

        tk.Label(
            panel_grid,
            text="Imagenes del batch actual",
            font=("Arial", 16, "bold"),
            bg="white",
            fg="#1f2d3d"
        ).pack(anchor="w", padx=15, pady=(15, 10))

        self.frame_imagenes = tk.Frame(panel_grid, bg="white")
        self.frame_imagenes.pack(fill="both", expand=True, padx=10, pady=10)

        for fila in range(2):
            for col in range(2):

                tarjeta = tk.Frame(self.frame_imagenes, bg="#fafafa", bd=1, relief="solid")
                tarjeta.grid(row=fila, column=col, padx=8, pady=8, sticky="nsew")

                self.frame_imagenes.grid_rowconfigure(fila, weight=1)
                self.frame_imagenes.grid_columnconfigure(col, weight=1)

                lbl_img = tk.Label(
                    tarjeta,
                    bg="#eaeff2",
                    text="Sin imagen",
                    font=("Arial", 10),
                    fg="#607d8b"
                )
                lbl_img.pack(padx=10, pady=(10, 6))

                lbl_txt = tk.Label(
                    tarjeta,
                    text="Real: -\nPred: -",
                    font=("Arial", 10),
                    bg="#fafafa",
                    fg="#263238",
                    justify="center"
                )
                lbl_txt.pack(padx=5, pady=(0, 10))

                self.labels_imagenes.append((lbl_img, lbl_txt))

        # Consola
        panel_log = tk.Frame(panel_der, bg="white", bd=1, relief="solid")
        panel_log.pack(fill="both", expand=False, pady=(10, 0))

        tk.Label(
            panel_log,
            text="Salida del entrenamiento",
            font=("Arial", 16, "bold"),
            bg="white",
            fg="#1f2d3d"
        ).pack(anchor="w", padx=15, pady=(15, 10))

        self.text_log = tk.Text(panel_log, height=12, font=("Consolas", 10), bg="#0f172a", fg="#e2e8f0")
        self.text_log.pack(fill="both", expand=True, padx=15, pady=(0, 15))
        self.text_log.config(state="disabled")

    # -----------------------------------------------------
    # UTILIDADES UI
    # -----------------------------------------------------
    def escribir_log(self, texto):
        self.text_log.config(state="normal")
        self.text_log.insert("end", texto + "\n")
        self.text_log.see("end")
        self.text_log.config(state="disabled")

    def actualizar_estado(self, texto, color="#2e7d32"):
        self.lbl_estado.config(text=texto, fg=color)

    def detener_entrenamiento(self):
        self.stop_requested = True
        self.actualizar_estado("Se solicito detener el entrenamiento...", "#c62828")
        self.escribir_log(">> Solicitud de detencion enviada.")

    # -----------------------------------------------------
    # PREPARACION DE DATOS
    # -----------------------------------------------------
    def preparar_datos(self):
        transform = transforms.Compose([
            transforms.ToTensor()
        ])

        train_dataset = datasets.CIFAR10(
            root="./data",
            train=True,
            download=True,
            transform=transform
        )

        val_dataset = datasets.CIFAR10(
            root="./data",
            train=False,
            download=True,
            transform=transform
        )

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False
        )

    def preparar_modelo(self):
        self.model = CNNPequena(num_clases=10).to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

    # -----------------------------------------------------
    # INICIO
    # -----------------------------------------------------
    def iniciar_entrenamiento(self):
        if self.training_thread is not None and self.training_thread.is_alive():
            messagebox.showinfo("Aviso", "Ya hay un entrenamiento en ejecucion.")
            return

        try:
            self.num_epochs = int(self.entry_epochs.get())
            self.batch_size = int(self.entry_batch.get())
            self.learning_rate = float(self.entry_lr.get())

            if self.num_epochs <= 0 or self.batch_size <= 0 or self.learning_rate <= 0:
                raise ValueError
        except ValueError:
            messagebox.showerror("Error", "Revisa los parametros ingresados.")
            return

        self.stop_requested = False
        self.barra.start(10)

        self.actualizar_estado("Preparando dataset y modelo...", "#1565c0")
        self.escribir_log(">> Iniciando proceso...")
        self.escribir_log(f">> Dispositivo detectado: {self.device}")
        self.escribir_log(f">> Epocas: {self.num_epochs} | Batch size: {self.batch_size} | LR: {self.learning_rate}")

        self.training_thread = threading.Thread(target=self.entrenar, daemon=True)
        self.training_thread.start()

    # -----------------------------------------------------
    # ENTRENAMIENTO
    # -----------------------------------------------------
    def entrenar(self):
        try:
            self.queue_ui.put(("estado", "Descargando/cargando CIFAR-10...", "#1565c0"))
            self.preparar_datos()
            self.queue_ui.put(("log", ">> CIFAR-10 listo."))
            self.queue_ui.put(("estado", "Creando modelo...", "#1565c0"))

            self.preparar_modelo()
            self.queue_ui.put(("log", ">> Modelo creado correctamente."))

            total_batches = len(self.train_loader)

            for epoch in range(self.num_epochs):
                if self.stop_requested:
                    self.queue_ui.put(("log", ">> Entrenamiento detenido por el usuario."))
                    self.queue_ui.put(("estado", "Entrenamiento detenido.", "#c62828"))
                    self.queue_ui.put(("fin",))
                    return

                # -------- Entrenamiento --------
                self.model.train()
                train_loss = 0.0
                train_correct = 0
                train_total = 0

                for batch_idx, (images, labels) in enumerate(self.train_loader):
                    if self.stop_requested:
                        self.queue_ui.put(("log", ">> Entrenamiento detenido por el usuario."))
                        self.queue_ui.put(("estado", "Entrenamiento detenido.", "#c62828"))
                        self.queue_ui.put(("fin",))
                        return

                    images = images.to(self.device)
                    labels = labels.to(self.device)

                    self.optimizer.zero_grad()
                    outputs = self.model(images)
                    loss = self.criterion(outputs, labels)
                    loss.backward()
                    self.optimizer.step()

                    preds = outputs.argmax(dim=1)

                    train_loss += loss.item() * images.size(0)
                    train_correct += (preds == labels).sum().item()
                    train_total += labels.size(0)

                    batch_acc = (preds == labels).sum().item() / labels.size(0)
                    avg_train_acc = train_correct / train_total
                    avg_train_loss = train_loss / train_total

                    # actualizar imagenes cada pocos batches para no saturar la UI
                    if batch_idx == 0 or (batch_idx + 1) % 5 == 0:
                        self.queue_ui.put((
                            "imagenes",
                            images[:4].detach().cpu(),
                            labels[:4].detach().cpu(),
                            preds[:4].detach().cpu()
                        ))

                    self.queue_ui.put((
                        "metricas",
                        epoch + 1,
                        self.num_epochs,
                        batch_idx + 1,
                        total_batches,
                        loss.item(),
                        batch_acc,
                        avg_train_loss,
                        avg_train_acc,
                        None
                    ))

                    if (batch_idx + 1) % 20 == 0 or batch_idx == 0:
                        self.queue_ui.put((
                            "log",
                            f"Epoca {epoch+1}/{self.num_epochs} | "
                            f"Batch {batch_idx+1}/{total_batches} | "
                            f"loss={loss.item():.4f} | batch_acc={batch_acc:.4f}"
                        ))

                # -------- Validacion --------
                self.model.eval()
                val_correct = 0
                val_total = 0
                val_loss = 0.0

                with torch.no_grad():
                    for images, labels in self.val_loader:
                        images = images.to(self.device)
                        labels = labels.to(self.device)

                        outputs = self.model(images)
                        loss = self.criterion(outputs, labels)

                        preds = outputs.argmax(dim=1)

                        val_loss += loss.item() * images.size(0)
                        val_correct += (preds == labels).sum().item()
                        val_total += labels.size(0)

                avg_val_loss = val_loss / val_total
                val_acc = val_correct / val_total
                avg_train_loss = train_loss / train_total
                avg_train_acc = train_correct / train_total

                self.queue_ui.put((
                    "metricas",
                    epoch + 1,
                    self.num_epochs,
                    total_batches,
                    total_batches,
                    avg_train_loss,
                    avg_train_acc,
                    avg_train_loss,
                    avg_train_acc,
                    val_acc
                ))

                self.queue_ui.put((
                    "log",
                    f"FIN EPOCA {epoch+1}/{self.num_epochs} | "
                    f"train_loss={avg_train_loss:.4f} | "
                    f"train_acc={avg_train_acc:.4f} | "
                    f"val_loss={avg_val_loss:.4f} | "
                    f"val_acc={val_acc:.4f}"
                ))

            torch.save(self.model.state_dict(), "modelo_cifar10_tkinter.pth")
            self.queue_ui.put(("log", ">> Modelo guardado en modelo_cifar10_tkinter.pth"))
            self.queue_ui.put(("estado", "Entrenamiento finalizado correctamente.", "#2e7d32"))
            self.queue_ui.put(("fin",))

        except Exception as e:
            self.queue_ui.put(("log", f">> ERROR: {str(e)}"))
            self.queue_ui.put(("estado", "Ocurrio un error durante el entrenamiento.", "#c62828"))
            self.queue_ui.put(("fin",))

    # -----------------------------------------------------
    # ACTUALIZACION DE UI
    # -----------------------------------------------------
    def tensor_a_imagen_tk(self, tensor, ancho=220, alto=220):
        arr = tensor.permute(1, 2, 0).numpy()
        arr = np.clip(arr * 255, 0, 255).astype(np.uint8)

        img = Image.fromarray(arr)
        img = img.resize((ancho, alto), Image.Resampling.BICUBIC)
        return ImageTk.PhotoImage(img)

    def mostrar_imagenes_batch(self, images, labels, preds):
        self.imagenes_tk.clear()
        cantidad = min(4, len(images))

        for i in range(cantidad):
            lbl_img, lbl_txt = self.labels_imagenes[i]

            img_tk = self.tensor_a_imagen_tk(images[i])
            self.imagenes_tk.append(img_tk)

            lbl_img.config(image=img_tk, text="")
            lbl_img.image = img_tk

            clase_real = self.clases[int(labels[i].item())]
            clase_pred = self.clases[int(preds[i].item())]

            if clase_real == clase_pred:
                color = "#2e7d32"
            else:
                color = "#c62828"

            lbl_txt.config(
                text=f"Real: {clase_real}\nPred: {clase_pred}",
                fg=color
            )

        for i in range(cantidad, 4):
            lbl_img, lbl_txt = self.labels_imagenes[i]
            lbl_img.config(image="", text="Sin imagen")
            lbl_img.image = None
            lbl_txt.config(text="Real: -\nPred: -", fg="#263238")

    def actualizar_metricas(self, epoch, total_epochs, batch, total_batches,
                            loss, batch_acc, train_loss, train_acc, val_acc):
        texto = (
            f"Epoca: {epoch}/{total_epochs}\n"
            f"Batch: {batch}/{total_batches}\n"
            f"Loss: {loss:.4f}\n"
            f"Accuracy batch: {batch_acc:.4f}\n"
            f"Train acc: {train_acc:.4f}\n"
        )

        if val_acc is None:
            texto += "Val acc: calculando..."
        else:
            texto += f"Val acc: {val_acc:.4f}"

        self.lbl_metricas.config(text=texto)

    def procesar_cola(self):
        ultimo_estado = None
        ultimas_imagenes = None
        ultimas_metricas = None
        fin_recibido = False

        try:
            while True:
                item = self.queue_ui.get_nowait()
                tipo = item[0]

                if tipo == "log":
                    self.escribir_log(item[1])

                elif tipo == "estado":
                    ultimo_estado = item

                elif tipo == "imagenes":
                    ultimas_imagenes = item

                elif tipo == "metricas":
                    ultimas_metricas = item

                elif tipo == "fin":
                    fin_recibido = True

        except queue.Empty:
            pass

        if ultimo_estado is not None:
            _, texto, color = ultimo_estado
            self.actualizar_estado(texto, color)

        if ultimas_imagenes is not None:
            _, images, labels, preds = ultimas_imagenes
            self.mostrar_imagenes_batch(images, labels, preds)

        if ultimas_metricas is not None:
            _, epoch, total_epochs, batch, total_batches, loss, batch_acc, train_loss, train_acc, val_acc = ultimas_metricas
            self.actualizar_metricas(epoch, total_epochs, batch, total_batches, loss, batch_acc, train_loss, train_acc, val_acc)

        if fin_recibido:
            self.barra.stop()

        self.root.after(50, self.procesar_cola)


# =========================================================
# MAIN
# =========================================================
if __name__ == "__main__":
    root = tk.Tk()
    app = AppEntrenamientoCIFAR10(root)
    root.mainloop()

Aquí ya aparece una situación típica de trabajo real: el dataset se descarga automáticamente, se construyen los DataLoader, se entrena la red y además se visualiza qué está ocurriendo mientras avanza el proceso.

23.28 Problema final: seleccionar una imagen cualquiera y clasificarla

Una vez entrenado el modelo del punto anterior, podemos plantear un ejercicio final muy natural: permitir que el usuario elija una imagen de su computadora y pedirle a la red que intente clasificarla.

El enunciado del problema puede formularse así:

Dada una imagen seleccionada manualmente por el usuario, el sistema debe cargar el modelo entrenado con CIFAR-10, preprocesar la imagen y devolver la clase predicha junto con la confianza de la predicción.

Este ejemplo conecta entrenamiento e inferencia. Primero entrenamos y guardamos el archivo modelo_cifar10_tkinter.pth. Después lo reutilizamos para clasificar imágenes nuevas.

Interfaz para seleccionar una imagen y clasificarla con el modelo entrenado de CIFAR-10

El siguiente programa implementa esa etapa final:

import os
import tkinter as tk
from tkinter import ttk, filedialog, messagebox

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageTk


# =========================================================
# MODELO
# =========================================================
class CNNPequena(nn.Module):
    def __init__(self, num_clases=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_clases)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # 32x32 -> 16x16
        x = self.pool(F.relu(self.conv2(x)))   # 16x16 -> 8x8
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# =========================================================
# APP
# =========================================================
class AppPrediccionCIFAR10:
    def __init__(self, root):
        self.root = root
        self.root.title("Predicción de imágenes con modelo CIFAR-10")
        self.root.geometry("1100x760")
        self.root.configure(bg="#f4f6f8")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.ruta_modelo = "modelo_cifar10_tkinter.pth"

        self.clases = [
            "avion", "auto", "pajaro", "gato", "ciervo",
            "perro", "rana", "caballo", "barco", "camion"
        ]

        self.modelo = None
        self.imagen_original_pil = None
        self.imagen_preview_tk = None
        self.ruta_imagen = None

        self.transformacion = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor()
        ])

        self.crear_interfaz()
        self.cargar_modelo()

    # -----------------------------------------------------
    # INTERFAZ
    # -----------------------------------------------------
    def crear_interfaz(self):
        contenedor = tk.Frame(self.root, bg="#f4f6f8")
        contenedor.pack(fill="both", expand=True, padx=12, pady=12)

        titulo = tk.Label(
            contenedor,
            text="Clasificación de una imagen usando un modelo entrenado",
            font=("Arial", 22, "bold"),
            bg="#f4f6f8",
            fg="#1f2d3d"
        )
        titulo.pack(anchor="w")

        subtitulo = tk.Label(
            contenedor,
            text="Selecciona una imagen desde tu computadora y el modelo intentará clasificarla en una de las clases de CIFAR-10.",
            font=("Arial", 11),
            bg="#f4f6f8",
            fg="#546e7a"
        )
        subtitulo.pack(anchor="w", pady=(4, 12))

        cuerpo = tk.Frame(contenedor, bg="#f4f6f8")
        cuerpo.pack(fill="both", expand=True)

        # Panel izquierdo
        panel_izq = tk.Frame(cuerpo, bg="white", bd=1, relief="solid")
        panel_izq.pack(side="left", fill="y", padx=(0, 10))

        tk.Label(
            panel_izq,
            text="Modelo",
            font=("Arial", 16, "bold"),
            bg="white",
            fg="#1f2d3d"
        ).pack(anchor="w", padx=15, pady=(15, 10))

        self.lbl_estado_modelo = tk.Label(
            panel_izq,
            text="Cargando modelo...",
            font=("Arial", 11),
            bg="white",
            fg="#1565c0",
            justify="left",
            wraplength=280
        )
        self.lbl_estado_modelo.pack(anchor="w", padx=15)

        tk.Label(
            panel_izq,
            text=f"Dispositivo: {self.device}",
            font=("Arial", 11),
            bg="white",
            fg="#37474f"
        ).pack(anchor="w", padx=15, pady=(10, 20))

        ttk.Button(
            panel_izq,
            text="Seleccionar imagen",
            command=self.seleccionar_imagen
        ).pack(fill="x", padx=15, pady=5)

        ttk.Button(
            panel_izq,
            text="Predecir clase",
            command=self.predecir_imagen
        ).pack(fill="x", padx=15, pady=5)

        ttk.Button(
            panel_izq,
            text="Limpiar",
            command=self.limpiar
        ).pack(fill="x", padx=15, pady=5)

        tk.Label(
            panel_izq,
            text="Ruta de la imagen",
            font=("Arial", 14, "bold"),
            bg="white",
            fg="#1f2d3d"
        ).pack(anchor="w", padx=15, pady=(25, 8))

        self.lbl_ruta = tk.Label(
            panel_izq,
            text="Todavía no se seleccionó ninguna imagen.",
            font=("Arial", 10),
            bg="white",
            fg="#455a64",
            justify="left",
            wraplength=280
        )
        self.lbl_ruta.pack(anchor="w", padx=15)

        tk.Label(
            panel_izq,
            text="Resultado",
            font=("Arial", 14, "bold"),
            bg="white",
            fg="#1f2d3d"
        ).pack(anchor="w", padx=15, pady=(25, 8))

        self.lbl_resultado = tk.Label(
            panel_izq,
            text="Clase predicha: -",
            font=("Arial", 16, "bold"),
            bg="white",
            fg="#2e7d32",
            justify="left"
        )
        self.lbl_resultado.pack(anchor="w", padx=15)

        self.lbl_confianza = tk.Label(
            panel_izq,
            text="Confianza: -",
            font=("Arial", 11),
            bg="white",
            fg="#37474f",
            justify="left"
        )
        self.lbl_confianza.pack(anchor="w", padx=15, pady=(8, 15))

        # Panel derecho
        panel_der = tk.Frame(cuerpo, bg="#f4f6f8")
        panel_der.pack(side="left", fill="both", expand=True)

        # Vista imagen
        panel_imagen = tk.Frame(panel_der, bg="white", bd=1, relief="solid")
        panel_imagen.pack(fill="both", expand=True)

        tk.Label(
            panel_imagen,
            text="Imagen seleccionada",
            font=("Arial", 16, "bold"),
            bg="white",
            fg="#1f2d3d"
        ).pack(anchor="w", padx=15, pady=(15, 10))

        self.canvas_imagen = tk.Label(
            panel_imagen,
            bg="#eaeff2",
            width=500,
            height=350
        )
        self.canvas_imagen.pack(padx=15, pady=15)

        self.lbl_info_preproceso = tk.Label(
            panel_imagen,
            text="La imagen será redimensionada a 32x32 y convertida a tensor para que la red pueda procesarla.",
            font=("Arial", 11),
            bg="white",
            fg="#546e7a"
        )
        self.lbl_info_preproceso.pack(anchor="w", padx=15, pady=(0, 15))

        # Probabilidades
        panel_probs = tk.Frame(panel_der, bg="white", bd=1, relief="solid")
        panel_probs.pack(fill="both", expand=False, pady=(10, 0))

        tk.Label(
            panel_probs,
            text="Probabilidades por clase",
            font=("Arial", 16, "bold"),
            bg="white",
            fg="#1f2d3d"
        ).pack(anchor="w", padx=15, pady=(15, 10))

        self.text_probs = tk.Text(
            panel_probs,
            height=12,
            font=("Consolas", 11),
            bg="#0f172a",
            fg="#e2e8f0"
        )
        self.text_probs.pack(fill="both", expand=True, padx=15, pady=(0, 15))
        self.text_probs.config(state="disabled")

    # -----------------------------------------------------
    # MODELO
    # -----------------------------------------------------
    def cargar_modelo(self):
        if not os.path.exists(self.ruta_modelo):
            self.lbl_estado_modelo.config(
                text=f"No se encontró el archivo {self.ruta_modelo}",
                fg="#c62828"
            )
            messagebox.showerror(
                "Error",
                f"No se encontró el archivo del modelo:\n{self.ruta_modelo}\n\n"
                "Asegúrate de ejecutar primero la aplicación de entrenamiento."
            )
            return

        try:
            self.modelo = CNNPequena(num_clases=10).to(self.device)
            state_dict = torch.load(self.ruta_modelo, map_location=self.device)
            self.modelo.load_state_dict(state_dict)
            self.modelo.eval()

            self.lbl_estado_modelo.config(
                text=f"Modelo cargado correctamente desde:\n{self.ruta_modelo}",
                fg="#2e7d32"
            )
        except Exception as e:
            self.lbl_estado_modelo.config(
                text="Error al cargar el modelo.",
                fg="#c62828"
            )
            messagebox.showerror("Error", f"No se pudo cargar el modelo.\n\n{e}")

    # -----------------------------------------------------
    # IMAGEN
    # -----------------------------------------------------
    def seleccionar_imagen(self):
        ruta = filedialog.askopenfilename(
            title="Seleccionar imagen",
            filetypes=[
                ("Imágenes", "*.png *.jpg *.jpeg *.bmp *.webp"),
                ("Todos los archivos", "*.*")
            ]
        )

        if not ruta:
            return

        try:
            imagen = Image.open(ruta).convert("RGB")
            self.imagen_original_pil = imagen
            self.ruta_imagen = ruta

            self.mostrar_preview(imagen)

            self.lbl_ruta.config(text=ruta)
            self.lbl_resultado.config(text="Clase predicha: -", fg="#2e7d32")
            self.lbl_confianza.config(text="Confianza: -")
            self.limpiar_probabilidades()

        except Exception as e:
            messagebox.showerror("Error", f"No se pudo abrir la imagen.\n\n{e}")

    def mostrar_preview(self, imagen_pil):
        copia = imagen_pil.copy()
        copia.thumbnail((500, 350))
        self.imagen_preview_tk = ImageTk.PhotoImage(copia)
        self.canvas_imagen.config(image=self.imagen_preview_tk)

    def preparar_tensor(self, imagen_pil):
        tensor = self.transformacion(imagen_pil)
        tensor = tensor.unsqueeze(0)  # agregar dimensión batch
        return tensor.to(self.device)

    # -----------------------------------------------------
    # PREDICCIÓN
    # -----------------------------------------------------
    def predecir_imagen(self):
        if self.modelo is None:
            messagebox.showerror("Error", "El modelo no está cargado.")
            return

        if self.imagen_original_pil is None:
            messagebox.showwarning("Aviso", "Primero selecciona una imagen.")
            return

        try:
            tensor = self.preparar_tensor(self.imagen_original_pil)

            with torch.no_grad():
                salida = self.modelo(tensor)
                probabilidades = torch.softmax(salida, dim=1)
                indice = probabilidades.argmax(dim=1).item()
                confianza = probabilidades[0, indice].item()

            clase_predicha = self.clases[indice]

            self.lbl_resultado.config(
                text=f"Clase predicha: {clase_predicha}",
                fg="#1565c0"
            )
            self.lbl_confianza.config(
                text=f"Confianza: {confianza * 100:.2f}%"
            )

            self.mostrar_probabilidades(probabilidades[0].cpu())

        except Exception as e:
            messagebox.showerror("Error", f"No se pudo realizar la predicción.\n\n{e}")

    # -----------------------------------------------------
    # SALIDA
    # -----------------------------------------------------
    def mostrar_probabilidades(self, probs):
        pares = list(zip(self.clases, probs.tolist()))
        pares.sort(key=lambda x: x[1], reverse=True)

        self.text_probs.config(state="normal")
        self.text_probs.delete("1.0", "end")

        self.text_probs.insert("end", "Ranking de clases:\n\n")
        for clase, prob in pares:
            self.text_probs.insert("end", f"{clase:<10} -> {prob * 100:6.2f}%\n")

        self.text_probs.config(state="disabled")

    def limpiar_probabilidades(self):
        self.text_probs.config(state="normal")
        self.text_probs.delete("1.0", "end")
        self.text_probs.config(state="disabled")

    def limpiar(self):
        self.imagen_original_pil = None
        self.imagen_preview_tk = None
        self.ruta_imagen = None

        self.canvas_imagen.config(image="")
        self.lbl_ruta.config(text="Todavía no se seleccionó ninguna imagen.")
        self.lbl_resultado.config(text="Clase predicha: -", fg="#2e7d32")
        self.lbl_confianza.config(text="Confianza: -")
        self.limpiar_probabilidades()


# =========================================================
# MAIN
# =========================================================
if __name__ == "__main__":
    root = tk.Tk()
    app = AppPrediccionCIFAR10(root)
    root.mainloop()

Este programa no vuelve a entrenar la red. Su trabajo es cargar los pesos ya aprendidos, preparar una imagen nueva, ejecutar una pasada hacia adelante y mostrar tanto la clase más probable como el ranking completo de probabilidades.

23.29 De dónde salen los datos y dónde se guardan

Los datos salen de torchvision.datasets.CIFAR10. Esa clase conoce la URL oficial del dataset y, cuando usamos download=True, se encarga de bajarlo si todavía no existe en nuestra máquina.

En este programa se usa:

train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

Eso significa que el dataset se guarda dentro de una carpeta llamada data ubicada en el directorio desde el cual ejecutamos el script.

Si ejecutamos el programa desde la carpeta del proyecto, la estructura típica que veremos será algo parecida a esta:

proyecto/
  data/
    cifar-10-batches-py/
    cifar-10-python.tar.gz

El archivo comprimido suele descargarse como cifar-10-python.tar.gz. Luego torchvision lo descomprime automáticamente y deja disponible la carpeta cifar-10-batches-py, que contiene los archivos internos del dataset.

Es decir, normalmente no tenemos que descomprimir nada a mano. La librería hace ese trabajo por nosotros si el dataset todavía no está preparado.

23.30 Qué está resolviendo exactamente este programa

Este ejemplo resuelve un problema de clasificación multiclase supervisada. Cada imagen entra a la CNN como un tensor de 3 canales y tamaño 32x32, y la red produce 10 valores de salida, uno por cada clase posible.

Durante el entrenamiento, CrossEntropyLoss compara esos 10 valores con la etiqueta correcta. Después, Adam ajusta los pesos para que la clase correcta tienda a recibir una puntuación más alta en futuras iteraciones.

La interfaz en Tkinter agrega algo muy útil desde el punto de vista didáctico: permite ver ejemplos concretos del batch actual, la clase real, la clase predicha y las métricas acumuladas mientras la red aprende.

Por eso este ejemplo ya se parece bastante más a una aplicación de trabajo real que al caso anterior con FakeData.

23.31 Errores comunes al entrenar

Algunos errores frecuentes son:

  • Olvidar optimizer.zero_grad().
  • No mover tensores y modelo al mismo dispositivo.
  • Usar mal las dimensiones de entrada.
  • Aplicar softmax antes de CrossEntropyLoss.
  • No usar model.eval() ni torch.no_grad() en validación.
  • Mirar solo entrenamiento y no validar.

Estos errores son muy comunes incluso en personas que ya entienden la teoría general.

23.32 Qué debes recordar de este tema

  • Entrenar un clasificador de imágenes implica ajustar los pesos del modelo a partir del error sobre ejemplos etiquetados.
  • Los componentes básicos son modelo, dataset, DataLoader, función de pérdida y optimizador.
  • En clasificación multiclase, nn.CrossEntropyLoss() es una elección muy habitual.
  • El loop de entrenamiento típico incluye forward, cálculo de pérdida, backward y actualización del optimizador.
  • La validación debe ejecutarse en modo evaluación y sin gradientes.
  • La accuracy es una métrica simple y útil, pero debe leerse junto con la pérdida.
  • Antes de trabajar con datasets reales, conviene dominar esta estructura básica con ejemplos pequeños y controlados.

23.33 Conclusión

Entrenar un clasificador de imágenes en PyTorch es el punto donde las piezas principales del pipeline de Deep Learning empiezan a encajar de verdad. La arquitectura deja de ser un objeto estático y se convierte en un sistema que aprende, se equivoca, corrige sus pesos y mejora progresivamente a partir de los datos.

Lo importante en esta etapa es dominar la lógica general: cómo se organiza el loop, qué rol cumple la pérdida, cómo interviene el optimizador y por qué la validación es necesaria para interpretar si el modelo realmente está aprendiendo algo útil.

En el próximo tema ampliaremos esta mirada centrándonos en la evaluación de modelos de visión por computadora, para entender con más precisión cómo medir el desempeño de una red más allá de observar solo la pérdida o la accuracy de entrenamiento.