En el tema anterior construimos una CNN en PyTorch. Definimos sus capas, entendimos el papel de nn.Module y verificamos que el modelo podía recibir imágenes y producir una salida con la forma esperada. Pero todavía faltaba lo más importante: hacer que aprenda.
Entrenar un clasificador de imágenes significa ajustar los pesos del modelo para que, a partir de ejemplos etiquetados, aprenda a asociar patrones visuales con clases correctas. Esa etapa requiere mucho más que una arquitectura: necesitamos datos, una función de pérdida, un optimizador y un ciclo de entrenamiento bien planteado.
En este tema veremos la lógica completa del entrenamiento en PyTorch. La idea es comprender no solo qué líneas de código hay que escribir, sino también qué está ocurriendo conceptualmente en cada paso.
Entrenar un modelo significa exponerlo repetidamente a ejemplos de entrada y comparar sus predicciones con las etiquetas correctas. A partir de ese error, el algoritmo ajusta los pesos para que futuras predicciones sean mejores.
En clasificación de imágenes, el proceso general es:
Este ciclo se repite muchas veces hasta que el modelo mejora su desempeño.
Para entrenar un clasificador de imágenes en PyTorch necesitamos cuatro elementos centrales:
DataLoader.Además, necesitamos organizar el loop de entrenamiento y, si queremos hacerlo bien, también un conjunto de validación.
El modelo no aprende en el vacío: aprende a partir de un dataset etiquetado. Cada ejemplo contiene una imagen y una clase asociada.
En PyTorch, un dataset suele devolver pares del estilo:
(imagen, etiqueta)
La imagen se representa como tensor y la etiqueta suele ser un entero que identifica la clase. Por ejemplo, en un problema de 10 clases, las etiquetas pueden ir de 0 a 9.
DataLoader?Aunque podríamos recorrer el dataset manualmente, PyTorch ofrece DataLoader para automatizar varias tareas importantes:
Esto simplifica mucho el código y hace el entrenamiento más ordenado.
En la práctica, el modelo no suele actualizarse imagen por imagen, sino por lotes o batches. Un batch puede tener 8, 16, 32 o más ejemplos, según el problema y la memoria disponible.
Trabajar por batches tiene varias ventajas:
La función de pérdida mide qué tan equivocada estuvo la red. Es el puente entre la predicción del modelo y la corrección matemática necesaria para ajustar sus parámetros.
En clasificación multiclase, una de las opciones más comunes es nn.CrossEntropyLoss().
criterion = nn.CrossEntropyLoss()
Esta función espera logits como salida del modelo y etiquetas enteras como objetivo.
Porque es una función pensada justamente para clasificación multiclase. Castiga al modelo cuando asigna baja puntuación a la clase correcta y alta puntuación a clases incorrectas.
Además, PyTorch la implementa de forma estable y eficiente. Un detalle importante es que, al usar CrossEntropyLoss, la red no debe aplicar softmax manualmente en la salida final.
Una vez calculada la pérdida, necesitamos un mecanismo para actualizar los pesos. Ese es el papel del optimizador.
Un optimizador muy usado es Adam:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
Aquí le decimos al optimizador qué parámetros debe modificar y con qué tasa de aprendizaje hacerlo.
El corazón del entrenamiento está en esta secuencia:
En PyTorch, esta lógica suele verse así:
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
zero_grad()Antes de calcular nuevos gradientes, normalmente se limpian los gradientes acumulados del paso anterior:
optimizer.zero_grad()
Esto es importante porque PyTorch acumula gradientes por defecto. Si no los limpiamos, las actualizaciones mezclarían información de varios pasos y producirían un comportamiento incorrecto.
El entrenamiento completo se organiza normalmente en épocas. Cada época recorre todos los batches del conjunto de entrenamiento.
La estructura general es:
Este es el patrón básico que aparece en casi cualquier proyecto serio con PyTorch.
PyTorch distingue entre modo entrenamiento y modo evaluación. Esto es importante porque algunas capas, como Dropout o BatchNorm, se comportan distinto según el contexto.
Para entrenamiento se usa:
model.train()
Y para validación o inferencia:
model.eval()
Aunque nuestra CNN mínima no use todas esas capas, conviene adoptar esta práctica desde el principio.
Entrenar bien no significa solo bajar la pérdida de entrenamiento. También necesitamos saber cómo se comporta el modelo sobre datos no usados directamente para ajustar los pesos. Para eso sirve el conjunto de validación.
Durante validación no se actualizan parámetros. Solo se mide el desempeño del modelo para ver si está generalizando o si empieza a sobreajustarse.
torch.no_grad()Cuando validamos o inferimos, no necesitamos gradientes. Por eso PyTorch permite envolver ese bloque con:
with torch.no_grad():
outputs = model(images)
Esto ahorra memoria y acelera el proceso, además de dejar claro que no queremos actualizar nada en esa etapa.
Además de la pérdida, en clasificación suele medirse la accuracy, es decir, el porcentaje de ejemplos clasificados correctamente.
Una forma habitual de obtener la clase predicha es usando argmax sobre la dimensión de clases:
preds = outputs.argmax(dim=1)
Luego se compara preds con las etiquetas reales y se cuentan los aciertos.
Algunas señales típicas de que el entrenamiento progresa son:
Si el entrenamiento mejora pero la validación empeora, puede estar apareciendo overfitting.
En muchos proyectos conviene guardar el estado del modelo cuando logra su mejor rendimiento de validación. PyTorch lo hace sencillo con torch.save.
torch.save(model.state_dict(), "mejor_modelo.pth")
Luego ese archivo puede cargarse más adelante para continuar entrenamiento o hacer inferencia.
Para que el código de este tema sea fácil de ejecutar sin depender de descargas externas, usaremos torchvision.datasets.FakeData. Este dataset genera imágenes y etiquetas sintéticas, lo que permite practicar toda la lógica del entrenamiento sin preocuparse por conseguir archivos reales.
El objetivo aquí no es obtener una accuracy impresionante, sino entender el pipeline completo.
Con FakeData podemos crear entrenamiento y validación así:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.ToTensor()
train_dataset = datasets.FakeData(
size=800,
image_size=(3, 64, 64),
num_classes=4,
transform=transform
)
val_dataset = datasets.FakeData(
size=200,
image_size=(3, 64, 64),
num_classes=4,
transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
Esto ya deja listo el flujo de datos para entrenamiento y validación.
Podemos reutilizar una CNN pequeña, similar a la del tema anterior, adaptándola a 4 clases:
class CNNPequena(nn.Module):
def __init__(self, num_clases=4):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 16 * 16, 128)
self.fc2 = nn.Linear(128, num_clases)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
Dentro de cada época solemos acumular pérdida y aciertos para luego calcular métricas promedio. Eso permite seguir la evolución del modelo de forma clara.
Conceptualmente, cada época hace dos cosas:
train_loader.val_loader.Este ejemplo reúne modelo, dataset, optimizador, pérdida, entrenamiento y validación:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
class CNNPequena(nn.Module):
def __init__(self, num_clases=4):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 16 * 16, 128)
self.fc2 = nn.Linear(128, num_clases)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = transforms.ToTensor()
train_dataset = datasets.FakeData(
size=800,
image_size=(3, 64, 64),
num_classes=4,
transform=transform
)
val_dataset = datasets.FakeData(
size=200,
image_size=(3, 64, 64),
num_classes=4,
transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
model = CNNPequena(num_clases=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 5
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item() * images.size(0)
preds = outputs.argmax(dim=1)
train_correct += (preds == labels).sum().item()
train_total += labels.size(0)
avg_train_loss = train_loss / train_total
train_acc = train_correct / train_total
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for images, labels in val_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item() * images.size(0)
preds = outputs.argmax(dim=1)
val_correct += (preds == labels).sum().item()
val_total += labels.size(0)
avg_val_loss = val_loss / val_total
val_acc = val_correct / val_total
print(
f"Epoca {epoch+1}/{num_epochs} | "
f"train_loss={avg_train_loss:.4f} | train_acc={train_acc:.4f} | "
f"val_loss={avg_val_loss:.4f} | val_acc={val_acc:.4f}"
)
Este script ya permite ver el entrenamiento completo en funcionamiento y observar cómo evolucionan las métricas a lo largo de las épocas.
Vale la pena leer ese código por bloques:
Esta estructura es extremadamente común en proyectos con PyTorch, incluso cuando el problema real es mucho más grande.
Como usa datos sintéticos, no debemos interpretar sus métricas como si fueran el resultado de un problema real. Su valor es didáctico.
Cuando trabajemos con datasets reales, habrá que sumar aspectos como:
Ahora sí podemos plantear un ejercicio más cercano a un caso real. En lugar de generar imágenes artificiales, vamos a usar CIFAR-10, uno de los datasets clásicos de visión por computadora.
El enunciado del problema puede formularse así:
CIFAR-10 contiene 60.000 imágenes en total, divididas en 50.000 para entrenamiento y 10.000 para prueba. Cada imagen ya viene etiquetada con una de esas diez categorías.
Este problema ya no es sintético: las imágenes son reales, pequeñas y variadas, por lo que el modelo debe aprender patrones visuales auténticos.
El siguiente programa descarga el dataset, entrena una CNN pequeña y muestra en vivo imágenes del batch actual junto con las predicciones del modelo:
import threading
import queue
import tkinter as tk
from tkinter import ttk, messagebox
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image, ImageTk
import numpy as np
# =========================================================
# MODELO
# =========================================================
class CNNPequena(nn.Module):
def __init__(self, num_clases=10):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 8 * 8, 128)
self.fc2 = nn.Linear(128, num_clases)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x))) # 32x32 -> 16x16
x = self.pool(F.relu(self.conv2(x))) # 16x16 -> 8x8
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# =========================================================
# APP
# =========================================================
class AppEntrenamientoCIFAR10:
def __init__(self, root):
self.root = root
self.root.title("Entrenamiento de CIFAR-10 en vivo")
self.root.geometry("1450x920")
self.root.configure(bg="#f4f6f8")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.clases = [
"avion", "auto", "pajaro", "gato", "ciervo",
"perro", "rana", "caballo", "barco", "camion"
]
self.model = None
self.train_loader = None
self.val_loader = None
self.optimizer = None
self.criterion = None
self.batch_size = 16
self.num_epochs = 5
self.learning_rate = 0.001
self.imagenes_tk = []
self.labels_imagenes = []
self.queue_ui = queue.Queue()
self.stop_requested = False
self.training_thread = None
self.crear_interfaz()
self.root.after(200, self.procesar_cola)
# -----------------------------------------------------
# INTERFAZ
# -----------------------------------------------------
def crear_interfaz(self):
marco_principal = tk.Frame(self.root, bg="#f4f6f8")
marco_principal.pack(fill="both", expand=True, padx=10, pady=10)
# Panel superior
panel_superior = tk.Frame(marco_principal, bg="#f4f6f8")
panel_superior.pack(fill="x", pady=(0, 10))
titulo = tk.Label(
panel_superior,
text="Clasificador de imagenes con CIFAR-10",
font=("Arial", 22, "bold"),
bg="#f4f6f8",
fg="#1f2d3d"
)
titulo.pack(anchor="w")
subtitulo = tk.Label(
panel_superior,
text="La aplicacion descarga el dataset, entrena una CNN y muestra en vivo imagenes del batch actual.",
font=("Arial", 11),
bg="#f4f6f8",
fg="#4a6572"
)
subtitulo.pack(anchor="w", pady=(4, 0))
# Panel contenido
panel_contenido = tk.Frame(marco_principal, bg="#f4f6f8")
panel_contenido.pack(fill="both", expand=True)
# Panel izquierdo controles
panel_izq = tk.Frame(panel_contenido, bg="white", bd=1, relief="solid")
panel_izq.pack(side="left", fill="y", padx=(0, 10))
tk.Label(
panel_izq,
text="Parametros",
font=("Arial", 16, "bold"),
bg="white",
fg="#1f2d3d"
).pack(anchor="w", padx=15, pady=(15, 10))
frm_params = tk.Frame(panel_izq, bg="white")
frm_params.pack(fill="x", padx=15, pady=5)
tk.Label(frm_params, text="Epocas:", font=("Arial", 11), bg="white").grid(row=0, column=0, sticky="w", pady=6)
self.entry_epochs = ttk.Entry(frm_params, width=10)
self.entry_epochs.grid(row=0, column=1, sticky="w", pady=6)
self.entry_epochs.insert(0, "5")
tk.Label(frm_params, text="Batch size:", font=("Arial", 11), bg="white").grid(row=1, column=0, sticky="w", pady=6)
self.entry_batch = ttk.Entry(frm_params, width=10)
self.entry_batch.grid(row=1, column=1, sticky="w", pady=6)
self.entry_batch.insert(0, "16")
tk.Label(frm_params, text="Learning rate:", font=("Arial", 11), bg="white").grid(row=2, column=0, sticky="w", pady=6)
self.entry_lr = ttk.Entry(frm_params, width=10)
self.entry_lr.grid(row=2, column=1, sticky="w", pady=6)
self.entry_lr.insert(0, "0.001")
tk.Label(
panel_izq,
text="Dispositivo:",
font=("Arial", 11, "bold"),
bg="white"
).pack(anchor="w", padx=15, pady=(15, 2))
self.lbl_device = tk.Label(
panel_izq,
text=str(self.device),
font=("Arial", 11),
bg="white",
fg="#1565c0"
)
self.lbl_device.pack(anchor="w", padx=15)
frm_botones = tk.Frame(panel_izq, bg="white")
frm_botones.pack(fill="x", padx=15, pady=20)
self.btn_iniciar = ttk.Button(frm_botones, text="Descargar y entrenar", command=self.iniciar_entrenamiento)
self.btn_iniciar.pack(fill="x", pady=4)
self.btn_detener = ttk.Button(frm_botones, text="Detener", command=self.detener_entrenamiento)
self.btn_detener.pack(fill="x", pady=4)
# Estado
tk.Label(
panel_izq,
text="Estado",
font=("Arial", 16, "bold"),
bg="white",
fg="#1f2d3d"
).pack(anchor="w", padx=15, pady=(10, 10))
self.lbl_estado = tk.Label(
panel_izq,
text="Listo para comenzar",
font=("Arial", 11),
bg="white",
fg="#2e7d32",
justify="left",
wraplength=260
)
self.lbl_estado.pack(anchor="w", padx=15)
self.barra = ttk.Progressbar(panel_izq, mode="indeterminate")
self.barra.pack(fill="x", padx=15, pady=15)
self.lbl_metricas = tk.Label(
panel_izq,
text="Epoca: -\nBatch: -\nLoss: -\nAccuracy batch: -\nTrain acc: -\nVal acc: -",
font=("Consolas", 11),
bg="white",
fg="#263238",
justify="left"
)
self.lbl_metricas.pack(anchor="w", padx=15, pady=(0, 20))
# Panel derecho
panel_der = tk.Frame(panel_contenido, bg="#f4f6f8")
panel_der.pack(side="left", fill="both", expand=True)
# Cuadricula de imagenes
panel_grid = tk.Frame(panel_der, bg="white", bd=1, relief="solid")
panel_grid.pack(fill="both", expand=True)
tk.Label(
panel_grid,
text="Imagenes del batch actual",
font=("Arial", 16, "bold"),
bg="white",
fg="#1f2d3d"
).pack(anchor="w", padx=15, pady=(15, 10))
self.frame_imagenes = tk.Frame(panel_grid, bg="white")
self.frame_imagenes.pack(fill="both", expand=True, padx=10, pady=10)
for fila in range(2):
for col in range(2):
tarjeta = tk.Frame(self.frame_imagenes, bg="#fafafa", bd=1, relief="solid")
tarjeta.grid(row=fila, column=col, padx=8, pady=8, sticky="nsew")
self.frame_imagenes.grid_rowconfigure(fila, weight=1)
self.frame_imagenes.grid_columnconfigure(col, weight=1)
lbl_img = tk.Label(
tarjeta,
bg="#eaeff2",
text="Sin imagen",
font=("Arial", 10),
fg="#607d8b"
)
lbl_img.pack(padx=10, pady=(10, 6))
lbl_txt = tk.Label(
tarjeta,
text="Real: -\nPred: -",
font=("Arial", 10),
bg="#fafafa",
fg="#263238",
justify="center"
)
lbl_txt.pack(padx=5, pady=(0, 10))
self.labels_imagenes.append((lbl_img, lbl_txt))
# Consola
panel_log = tk.Frame(panel_der, bg="white", bd=1, relief="solid")
panel_log.pack(fill="both", expand=False, pady=(10, 0))
tk.Label(
panel_log,
text="Salida del entrenamiento",
font=("Arial", 16, "bold"),
bg="white",
fg="#1f2d3d"
).pack(anchor="w", padx=15, pady=(15, 10))
self.text_log = tk.Text(panel_log, height=12, font=("Consolas", 10), bg="#0f172a", fg="#e2e8f0")
self.text_log.pack(fill="both", expand=True, padx=15, pady=(0, 15))
self.text_log.config(state="disabled")
# -----------------------------------------------------
# UTILIDADES UI
# -----------------------------------------------------
def escribir_log(self, texto):
self.text_log.config(state="normal")
self.text_log.insert("end", texto + "\n")
self.text_log.see("end")
self.text_log.config(state="disabled")
def actualizar_estado(self, texto, color="#2e7d32"):
self.lbl_estado.config(text=texto, fg=color)
def detener_entrenamiento(self):
self.stop_requested = True
self.actualizar_estado("Se solicito detener el entrenamiento...", "#c62828")
self.escribir_log(">> Solicitud de detencion enviada.")
# -----------------------------------------------------
# PREPARACION DE DATOS
# -----------------------------------------------------
def preparar_datos(self):
transform = transforms.Compose([
transforms.ToTensor()
])
train_dataset = datasets.CIFAR10(
root="./data",
train=True,
download=True,
transform=transform
)
val_dataset = datasets.CIFAR10(
root="./data",
train=False,
download=True,
transform=transform
)
self.train_loader = DataLoader(
train_dataset,
batch_size=self.batch_size,
shuffle=True
)
self.val_loader = DataLoader(
val_dataset,
batch_size=self.batch_size,
shuffle=False
)
def preparar_modelo(self):
self.model = CNNPequena(num_clases=10).to(self.device)
self.criterion = nn.CrossEntropyLoss()
self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
# -----------------------------------------------------
# INICIO
# -----------------------------------------------------
def iniciar_entrenamiento(self):
if self.training_thread is not None and self.training_thread.is_alive():
messagebox.showinfo("Aviso", "Ya hay un entrenamiento en ejecucion.")
return
try:
self.num_epochs = int(self.entry_epochs.get())
self.batch_size = int(self.entry_batch.get())
self.learning_rate = float(self.entry_lr.get())
if self.num_epochs <= 0 or self.batch_size <= 0 or self.learning_rate <= 0:
raise ValueError
except ValueError:
messagebox.showerror("Error", "Revisa los parametros ingresados.")
return
self.stop_requested = False
self.barra.start(10)
self.actualizar_estado("Preparando dataset y modelo...", "#1565c0")
self.escribir_log(">> Iniciando proceso...")
self.escribir_log(f">> Dispositivo detectado: {self.device}")
self.escribir_log(f">> Epocas: {self.num_epochs} | Batch size: {self.batch_size} | LR: {self.learning_rate}")
self.training_thread = threading.Thread(target=self.entrenar, daemon=True)
self.training_thread.start()
# -----------------------------------------------------
# ENTRENAMIENTO
# -----------------------------------------------------
def entrenar(self):
try:
self.queue_ui.put(("estado", "Descargando/cargando CIFAR-10...", "#1565c0"))
self.preparar_datos()
self.queue_ui.put(("log", ">> CIFAR-10 listo."))
self.queue_ui.put(("estado", "Creando modelo...", "#1565c0"))
self.preparar_modelo()
self.queue_ui.put(("log", ">> Modelo creado correctamente."))
total_batches = len(self.train_loader)
for epoch in range(self.num_epochs):
if self.stop_requested:
self.queue_ui.put(("log", ">> Entrenamiento detenido por el usuario."))
self.queue_ui.put(("estado", "Entrenamiento detenido.", "#c62828"))
self.queue_ui.put(("fin",))
return
# -------- Entrenamiento --------
self.model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
for batch_idx, (images, labels) in enumerate(self.train_loader):
if self.stop_requested:
self.queue_ui.put(("log", ">> Entrenamiento detenido por el usuario."))
self.queue_ui.put(("estado", "Entrenamiento detenido.", "#c62828"))
self.queue_ui.put(("fin",))
return
images = images.to(self.device)
labels = labels.to(self.device)
self.optimizer.zero_grad()
outputs = self.model(images)
loss = self.criterion(outputs, labels)
loss.backward()
self.optimizer.step()
preds = outputs.argmax(dim=1)
train_loss += loss.item() * images.size(0)
train_correct += (preds == labels).sum().item()
train_total += labels.size(0)
batch_acc = (preds == labels).sum().item() / labels.size(0)
avg_train_acc = train_correct / train_total
avg_train_loss = train_loss / train_total
# actualizar imagenes cada pocos batches para no saturar la UI
if batch_idx == 0 or (batch_idx + 1) % 5 == 0:
self.queue_ui.put((
"imagenes",
images[:4].detach().cpu(),
labels[:4].detach().cpu(),
preds[:4].detach().cpu()
))
self.queue_ui.put((
"metricas",
epoch + 1,
self.num_epochs,
batch_idx + 1,
total_batches,
loss.item(),
batch_acc,
avg_train_loss,
avg_train_acc,
None
))
if (batch_idx + 1) % 20 == 0 or batch_idx == 0:
self.queue_ui.put((
"log",
f"Epoca {epoch+1}/{self.num_epochs} | "
f"Batch {batch_idx+1}/{total_batches} | "
f"loss={loss.item():.4f} | batch_acc={batch_acc:.4f}"
))
# -------- Validacion --------
self.model.eval()
val_correct = 0
val_total = 0
val_loss = 0.0
with torch.no_grad():
for images, labels in self.val_loader:
images = images.to(self.device)
labels = labels.to(self.device)
outputs = self.model(images)
loss = self.criterion(outputs, labels)
preds = outputs.argmax(dim=1)
val_loss += loss.item() * images.size(0)
val_correct += (preds == labels).sum().item()
val_total += labels.size(0)
avg_val_loss = val_loss / val_total
val_acc = val_correct / val_total
avg_train_loss = train_loss / train_total
avg_train_acc = train_correct / train_total
self.queue_ui.put((
"metricas",
epoch + 1,
self.num_epochs,
total_batches,
total_batches,
avg_train_loss,
avg_train_acc,
avg_train_loss,
avg_train_acc,
val_acc
))
self.queue_ui.put((
"log",
f"FIN EPOCA {epoch+1}/{self.num_epochs} | "
f"train_loss={avg_train_loss:.4f} | "
f"train_acc={avg_train_acc:.4f} | "
f"val_loss={avg_val_loss:.4f} | "
f"val_acc={val_acc:.4f}"
))
torch.save(self.model.state_dict(), "modelo_cifar10_tkinter.pth")
self.queue_ui.put(("log", ">> Modelo guardado en modelo_cifar10_tkinter.pth"))
self.queue_ui.put(("estado", "Entrenamiento finalizado correctamente.", "#2e7d32"))
self.queue_ui.put(("fin",))
except Exception as e:
self.queue_ui.put(("log", f">> ERROR: {str(e)}"))
self.queue_ui.put(("estado", "Ocurrio un error durante el entrenamiento.", "#c62828"))
self.queue_ui.put(("fin",))
# -----------------------------------------------------
# ACTUALIZACION DE UI
# -----------------------------------------------------
def tensor_a_imagen_tk(self, tensor, ancho=220, alto=220):
arr = tensor.permute(1, 2, 0).numpy()
arr = np.clip(arr * 255, 0, 255).astype(np.uint8)
img = Image.fromarray(arr)
img = img.resize((ancho, alto), Image.Resampling.BICUBIC)
return ImageTk.PhotoImage(img)
def mostrar_imagenes_batch(self, images, labels, preds):
self.imagenes_tk.clear()
cantidad = min(4, len(images))
for i in range(cantidad):
lbl_img, lbl_txt = self.labels_imagenes[i]
img_tk = self.tensor_a_imagen_tk(images[i])
self.imagenes_tk.append(img_tk)
lbl_img.config(image=img_tk, text="")
lbl_img.image = img_tk
clase_real = self.clases[int(labels[i].item())]
clase_pred = self.clases[int(preds[i].item())]
if clase_real == clase_pred:
color = "#2e7d32"
else:
color = "#c62828"
lbl_txt.config(
text=f"Real: {clase_real}\nPred: {clase_pred}",
fg=color
)
for i in range(cantidad, 4):
lbl_img, lbl_txt = self.labels_imagenes[i]
lbl_img.config(image="", text="Sin imagen")
lbl_img.image = None
lbl_txt.config(text="Real: -\nPred: -", fg="#263238")
def actualizar_metricas(self, epoch, total_epochs, batch, total_batches,
loss, batch_acc, train_loss, train_acc, val_acc):
texto = (
f"Epoca: {epoch}/{total_epochs}\n"
f"Batch: {batch}/{total_batches}\n"
f"Loss: {loss:.4f}\n"
f"Accuracy batch: {batch_acc:.4f}\n"
f"Train acc: {train_acc:.4f}\n"
)
if val_acc is None:
texto += "Val acc: calculando..."
else:
texto += f"Val acc: {val_acc:.4f}"
self.lbl_metricas.config(text=texto)
def procesar_cola(self):
ultimo_estado = None
ultimas_imagenes = None
ultimas_metricas = None
fin_recibido = False
try:
while True:
item = self.queue_ui.get_nowait()
tipo = item[0]
if tipo == "log":
self.escribir_log(item[1])
elif tipo == "estado":
ultimo_estado = item
elif tipo == "imagenes":
ultimas_imagenes = item
elif tipo == "metricas":
ultimas_metricas = item
elif tipo == "fin":
fin_recibido = True
except queue.Empty:
pass
if ultimo_estado is not None:
_, texto, color = ultimo_estado
self.actualizar_estado(texto, color)
if ultimas_imagenes is not None:
_, images, labels, preds = ultimas_imagenes
self.mostrar_imagenes_batch(images, labels, preds)
if ultimas_metricas is not None:
_, epoch, total_epochs, batch, total_batches, loss, batch_acc, train_loss, train_acc, val_acc = ultimas_metricas
self.actualizar_metricas(epoch, total_epochs, batch, total_batches, loss, batch_acc, train_loss, train_acc, val_acc)
if fin_recibido:
self.barra.stop()
self.root.after(50, self.procesar_cola)
# =========================================================
# MAIN
# =========================================================
if __name__ == "__main__":
root = tk.Tk()
app = AppEntrenamientoCIFAR10(root)
root.mainloop()
Aquí ya aparece una situación típica de trabajo real: el dataset se descarga automáticamente, se construyen los DataLoader, se entrena la red y además se visualiza qué está ocurriendo mientras avanza el proceso.
Una vez entrenado el modelo del punto anterior, podemos plantear un ejercicio final muy natural: permitir que el usuario elija una imagen de su computadora y pedirle a la red que intente clasificarla.
El enunciado del problema puede formularse así:
Este ejemplo conecta entrenamiento e inferencia. Primero entrenamos y guardamos el archivo modelo_cifar10_tkinter.pth. Después lo reutilizamos para clasificar imágenes nuevas.
El siguiente programa implementa esa etapa final:
import os
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageTk
# =========================================================
# MODELO
# =========================================================
class CNNPequena(nn.Module):
def __init__(self, num_clases=10):
super().__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 8 * 8, 128)
self.fc2 = nn.Linear(128, num_clases)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x))) # 32x32 -> 16x16
x = self.pool(F.relu(self.conv2(x))) # 16x16 -> 8x8
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# =========================================================
# APP
# =========================================================
class AppPrediccionCIFAR10:
def __init__(self, root):
self.root = root
self.root.title("Predicción de imágenes con modelo CIFAR-10")
self.root.geometry("1100x760")
self.root.configure(bg="#f4f6f8")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.ruta_modelo = "modelo_cifar10_tkinter.pth"
self.clases = [
"avion", "auto", "pajaro", "gato", "ciervo",
"perro", "rana", "caballo", "barco", "camion"
]
self.modelo = None
self.imagen_original_pil = None
self.imagen_preview_tk = None
self.ruta_imagen = None
self.transformacion = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()
])
self.crear_interfaz()
self.cargar_modelo()
# -----------------------------------------------------
# INTERFAZ
# -----------------------------------------------------
def crear_interfaz(self):
contenedor = tk.Frame(self.root, bg="#f4f6f8")
contenedor.pack(fill="both", expand=True, padx=12, pady=12)
titulo = tk.Label(
contenedor,
text="Clasificación de una imagen usando un modelo entrenado",
font=("Arial", 22, "bold"),
bg="#f4f6f8",
fg="#1f2d3d"
)
titulo.pack(anchor="w")
subtitulo = tk.Label(
contenedor,
text="Selecciona una imagen desde tu computadora y el modelo intentará clasificarla en una de las clases de CIFAR-10.",
font=("Arial", 11),
bg="#f4f6f8",
fg="#546e7a"
)
subtitulo.pack(anchor="w", pady=(4, 12))
cuerpo = tk.Frame(contenedor, bg="#f4f6f8")
cuerpo.pack(fill="both", expand=True)
# Panel izquierdo
panel_izq = tk.Frame(cuerpo, bg="white", bd=1, relief="solid")
panel_izq.pack(side="left", fill="y", padx=(0, 10))
tk.Label(
panel_izq,
text="Modelo",
font=("Arial", 16, "bold"),
bg="white",
fg="#1f2d3d"
).pack(anchor="w", padx=15, pady=(15, 10))
self.lbl_estado_modelo = tk.Label(
panel_izq,
text="Cargando modelo...",
font=("Arial", 11),
bg="white",
fg="#1565c0",
justify="left",
wraplength=280
)
self.lbl_estado_modelo.pack(anchor="w", padx=15)
tk.Label(
panel_izq,
text=f"Dispositivo: {self.device}",
font=("Arial", 11),
bg="white",
fg="#37474f"
).pack(anchor="w", padx=15, pady=(10, 20))
ttk.Button(
panel_izq,
text="Seleccionar imagen",
command=self.seleccionar_imagen
).pack(fill="x", padx=15, pady=5)
ttk.Button(
panel_izq,
text="Predecir clase",
command=self.predecir_imagen
).pack(fill="x", padx=15, pady=5)
ttk.Button(
panel_izq,
text="Limpiar",
command=self.limpiar
).pack(fill="x", padx=15, pady=5)
tk.Label(
panel_izq,
text="Ruta de la imagen",
font=("Arial", 14, "bold"),
bg="white",
fg="#1f2d3d"
).pack(anchor="w", padx=15, pady=(25, 8))
self.lbl_ruta = tk.Label(
panel_izq,
text="Todavía no se seleccionó ninguna imagen.",
font=("Arial", 10),
bg="white",
fg="#455a64",
justify="left",
wraplength=280
)
self.lbl_ruta.pack(anchor="w", padx=15)
tk.Label(
panel_izq,
text="Resultado",
font=("Arial", 14, "bold"),
bg="white",
fg="#1f2d3d"
).pack(anchor="w", padx=15, pady=(25, 8))
self.lbl_resultado = tk.Label(
panel_izq,
text="Clase predicha: -",
font=("Arial", 16, "bold"),
bg="white",
fg="#2e7d32",
justify="left"
)
self.lbl_resultado.pack(anchor="w", padx=15)
self.lbl_confianza = tk.Label(
panel_izq,
text="Confianza: -",
font=("Arial", 11),
bg="white",
fg="#37474f",
justify="left"
)
self.lbl_confianza.pack(anchor="w", padx=15, pady=(8, 15))
# Panel derecho
panel_der = tk.Frame(cuerpo, bg="#f4f6f8")
panel_der.pack(side="left", fill="both", expand=True)
# Vista imagen
panel_imagen = tk.Frame(panel_der, bg="white", bd=1, relief="solid")
panel_imagen.pack(fill="both", expand=True)
tk.Label(
panel_imagen,
text="Imagen seleccionada",
font=("Arial", 16, "bold"),
bg="white",
fg="#1f2d3d"
).pack(anchor="w", padx=15, pady=(15, 10))
self.canvas_imagen = tk.Label(
panel_imagen,
bg="#eaeff2",
width=500,
height=350
)
self.canvas_imagen.pack(padx=15, pady=15)
self.lbl_info_preproceso = tk.Label(
panel_imagen,
text="La imagen será redimensionada a 32x32 y convertida a tensor para que la red pueda procesarla.",
font=("Arial", 11),
bg="white",
fg="#546e7a"
)
self.lbl_info_preproceso.pack(anchor="w", padx=15, pady=(0, 15))
# Probabilidades
panel_probs = tk.Frame(panel_der, bg="white", bd=1, relief="solid")
panel_probs.pack(fill="both", expand=False, pady=(10, 0))
tk.Label(
panel_probs,
text="Probabilidades por clase",
font=("Arial", 16, "bold"),
bg="white",
fg="#1f2d3d"
).pack(anchor="w", padx=15, pady=(15, 10))
self.text_probs = tk.Text(
panel_probs,
height=12,
font=("Consolas", 11),
bg="#0f172a",
fg="#e2e8f0"
)
self.text_probs.pack(fill="both", expand=True, padx=15, pady=(0, 15))
self.text_probs.config(state="disabled")
# -----------------------------------------------------
# MODELO
# -----------------------------------------------------
def cargar_modelo(self):
if not os.path.exists(self.ruta_modelo):
self.lbl_estado_modelo.config(
text=f"No se encontró el archivo {self.ruta_modelo}",
fg="#c62828"
)
messagebox.showerror(
"Error",
f"No se encontró el archivo del modelo:\n{self.ruta_modelo}\n\n"
"Asegúrate de ejecutar primero la aplicación de entrenamiento."
)
return
try:
self.modelo = CNNPequena(num_clases=10).to(self.device)
state_dict = torch.load(self.ruta_modelo, map_location=self.device)
self.modelo.load_state_dict(state_dict)
self.modelo.eval()
self.lbl_estado_modelo.config(
text=f"Modelo cargado correctamente desde:\n{self.ruta_modelo}",
fg="#2e7d32"
)
except Exception as e:
self.lbl_estado_modelo.config(
text="Error al cargar el modelo.",
fg="#c62828"
)
messagebox.showerror("Error", f"No se pudo cargar el modelo.\n\n{e}")
# -----------------------------------------------------
# IMAGEN
# -----------------------------------------------------
def seleccionar_imagen(self):
ruta = filedialog.askopenfilename(
title="Seleccionar imagen",
filetypes=[
("Imágenes", "*.png *.jpg *.jpeg *.bmp *.webp"),
("Todos los archivos", "*.*")
]
)
if not ruta:
return
try:
imagen = Image.open(ruta).convert("RGB")
self.imagen_original_pil = imagen
self.ruta_imagen = ruta
self.mostrar_preview(imagen)
self.lbl_ruta.config(text=ruta)
self.lbl_resultado.config(text="Clase predicha: -", fg="#2e7d32")
self.lbl_confianza.config(text="Confianza: -")
self.limpiar_probabilidades()
except Exception as e:
messagebox.showerror("Error", f"No se pudo abrir la imagen.\n\n{e}")
def mostrar_preview(self, imagen_pil):
copia = imagen_pil.copy()
copia.thumbnail((500, 350))
self.imagen_preview_tk = ImageTk.PhotoImage(copia)
self.canvas_imagen.config(image=self.imagen_preview_tk)
def preparar_tensor(self, imagen_pil):
tensor = self.transformacion(imagen_pil)
tensor = tensor.unsqueeze(0) # agregar dimensión batch
return tensor.to(self.device)
# -----------------------------------------------------
# PREDICCIÓN
# -----------------------------------------------------
def predecir_imagen(self):
if self.modelo is None:
messagebox.showerror("Error", "El modelo no está cargado.")
return
if self.imagen_original_pil is None:
messagebox.showwarning("Aviso", "Primero selecciona una imagen.")
return
try:
tensor = self.preparar_tensor(self.imagen_original_pil)
with torch.no_grad():
salida = self.modelo(tensor)
probabilidades = torch.softmax(salida, dim=1)
indice = probabilidades.argmax(dim=1).item()
confianza = probabilidades[0, indice].item()
clase_predicha = self.clases[indice]
self.lbl_resultado.config(
text=f"Clase predicha: {clase_predicha}",
fg="#1565c0"
)
self.lbl_confianza.config(
text=f"Confianza: {confianza * 100:.2f}%"
)
self.mostrar_probabilidades(probabilidades[0].cpu())
except Exception as e:
messagebox.showerror("Error", f"No se pudo realizar la predicción.\n\n{e}")
# -----------------------------------------------------
# SALIDA
# -----------------------------------------------------
def mostrar_probabilidades(self, probs):
pares = list(zip(self.clases, probs.tolist()))
pares.sort(key=lambda x: x[1], reverse=True)
self.text_probs.config(state="normal")
self.text_probs.delete("1.0", "end")
self.text_probs.insert("end", "Ranking de clases:\n\n")
for clase, prob in pares:
self.text_probs.insert("end", f"{clase:<10} -> {prob * 100:6.2f}%\n")
self.text_probs.config(state="disabled")
def limpiar_probabilidades(self):
self.text_probs.config(state="normal")
self.text_probs.delete("1.0", "end")
self.text_probs.config(state="disabled")
def limpiar(self):
self.imagen_original_pil = None
self.imagen_preview_tk = None
self.ruta_imagen = None
self.canvas_imagen.config(image="")
self.lbl_ruta.config(text="Todavía no se seleccionó ninguna imagen.")
self.lbl_resultado.config(text="Clase predicha: -", fg="#2e7d32")
self.lbl_confianza.config(text="Confianza: -")
self.limpiar_probabilidades()
# =========================================================
# MAIN
# =========================================================
if __name__ == "__main__":
root = tk.Tk()
app = AppPrediccionCIFAR10(root)
root.mainloop()
Este programa no vuelve a entrenar la red. Su trabajo es cargar los pesos ya aprendidos, preparar una imagen nueva, ejecutar una pasada hacia adelante y mostrar tanto la clase más probable como el ranking completo de probabilidades.
Los datos salen de torchvision.datasets.CIFAR10. Esa clase conoce la URL oficial del dataset y, cuando usamos download=True, se encarga de bajarlo si todavía no existe en nuestra máquina.
En este programa se usa:
train_dataset = datasets.CIFAR10(
root="./data",
train=True,
download=True,
transform=transform
)
Eso significa que el dataset se guarda dentro de una carpeta llamada data ubicada en el directorio desde el cual ejecutamos el script.
Si ejecutamos el programa desde la carpeta del proyecto, la estructura típica que veremos será algo parecida a esta:
proyecto/
data/
cifar-10-batches-py/
cifar-10-python.tar.gz
El archivo comprimido suele descargarse como cifar-10-python.tar.gz. Luego torchvision lo descomprime automáticamente y deja disponible la carpeta cifar-10-batches-py, que contiene los archivos internos del dataset.
Es decir, normalmente no tenemos que descomprimir nada a mano. La librería hace ese trabajo por nosotros si el dataset todavía no está preparado.
Este ejemplo resuelve un problema de clasificación multiclase supervisada. Cada imagen entra a la CNN como un tensor de 3 canales y tamaño 32x32, y la red produce 10 valores de salida, uno por cada clase posible.
Durante el entrenamiento, CrossEntropyLoss compara esos 10 valores con la etiqueta correcta. Después, Adam ajusta los pesos para que la clase correcta tienda a recibir una puntuación más alta en futuras iteraciones.
La interfaz en Tkinter agrega algo muy útil desde el punto de vista didáctico: permite ver ejemplos concretos del batch actual, la clase real, la clase predicha y las métricas acumuladas mientras la red aprende.
Por eso este ejemplo ya se parece bastante más a una aplicación de trabajo real que al caso anterior con FakeData.
Algunos errores frecuentes son:
optimizer.zero_grad().softmax antes de CrossEntropyLoss.model.eval() ni torch.no_grad() en validación.Estos errores son muy comunes incluso en personas que ya entienden la teoría general.
nn.CrossEntropyLoss() es una elección muy habitual.Entrenar un clasificador de imágenes en PyTorch es el punto donde las piezas principales del pipeline de Deep Learning empiezan a encajar de verdad. La arquitectura deja de ser un objeto estático y se convierte en un sistema que aprende, se equivoca, corrige sus pesos y mejora progresivamente a partir de los datos.
Lo importante en esta etapa es dominar la lógica general: cómo se organiza el loop, qué rol cumple la pérdida, cómo interviene el optimizador y por qué la validación es necesaria para interpretar si el modelo realmente está aprendiendo algo útil.
En el próximo tema ampliaremos esta mirada centrándonos en la evaluación de modelos de visión por computadora, para entender con más precisión cómo medir el desempeño de una red más allá de observar solo la pérdida o la accuracy de entrenamiento.