AI Course/week08_pytorch/week8_pytorch.py

Course file

week8_pytorch.py

week08_pytorch/week8_pytorch.py

"""Week 8: a tiny PyTorch classifier on synthetic data."""

from __future__ import annotations

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn


def make_dataset(seed: int = 7, n_per_class: int = 80) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample a two-class toy dataset of 2-D Gaussian blobs.

    Class 0 is centred at (-1.4, -1.0) and class 1 at (1.2, 1.1); both use a
    standard deviation of 0.45 on each axis.

    Args:
        seed: seed for NumPy's ``default_rng``; the same seed yields the same data.
        n_per_class: number of points drawn for each class.

    Returns:
        ``(features, labels)``. ``features`` is a float32 tensor of shape
        ``(2 * n_per_class, 2)`` holding all class-0 rows followed by all
        class-1 rows. ``labels`` is a float32 tensor of shape
        ``(2 * n_per_class, 1)`` holding 0.0 for class-0 rows and 1.0 for
        class-1 rows.
    """
    generator = np.random.default_rng(seed)
    spread = (0.45, 0.45)
    centres = ((-1.4, -1.0), (1.2, 1.1))

    # Class 0 is drawn before class 1; this order fixes which RNG samples
    # each class receives.
    blobs = [
        generator.normal(loc=centre, scale=spread, size=(n_per_class, 2))
        for centre in centres
    ]
    targets = [np.full(n_per_class, float(class_id)) for class_id in range(len(blobs))]

    x = np.concatenate(blobs, axis=0).astype(np.float32)
    y = np.concatenate(targets).astype(np.float32)
    # Add a trailing singleton axis so labels line up with the model's (N, 1) logits.
    return torch.from_numpy(x), torch.from_numpy(y)[:, None]


def build_model(hidden_size: int = 16) -> nn.Module:
    """Create a one-hidden-layer MLP mapping a 2-D point to a single logit.

    Layout: ``Linear(2 -> hidden_size)``, ``Tanh``, ``Linear(hidden_size -> 1)``.
    No sigmoid is applied, so the output is a raw logit.

    Args:
        hidden_size: width of the hidden layer.
    """
    in_features, out_features = 2, 1
    # Layers are created in forward order; weight init draws from torch's
    # global RNG in that same order.
    layers: list[nn.Module] = [
        nn.Linear(in_features, hidden_size),
        nn.Tanh(),
        nn.Linear(hidden_size, out_features),
    ]
    return nn.Sequential(*layers)


def _train(
    model: nn.Module,
    features: torch.Tensor,
    labels: torch.Tensor,
    *,
    epochs: int,
    lr: float,
    log_every: int = 25,
) -> list[float]:
    """Fit ``model`` with full-batch Adam on BCE-with-logits; return per-epoch losses.

    Prints the loss every ``log_every`` epochs and on the final epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()
    losses: list[float] = []

    # Set once: nothing inside the loop switches the model out of training mode.
    model.train()
    for epoch in range(epochs):
        logits = model(features)
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() already returns a Python float; read it once per epoch.
        loss_value = loss.item()
        losses.append(loss_value)
        if epoch % log_every == 0 or epoch == epochs - 1:
            print(f"epoch={epoch:03d} loss={loss_value:.4f}")
    return losses


def _accuracy(model: nn.Module, features: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of rows where ``sigmoid(logit) >= 0.5`` matches ``labels``."""
    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(features))
        predictions = (probs >= 0.5).float()
        return (predictions == labels).float().mean().item()


def _save_loss_curve(losses: list[float], output_path: Path) -> None:
    """Plot ``losses`` against epoch index and write the figure to ``output_path``."""
    fig, ax = plt.subplots(figsize=(7, 4))
    try:
        ax.plot(losses, color="tab:green", linewidth=2)
        ax.set_title("PyTorch Training Loss")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Binary cross-entropy")
        ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Close even if saving fails, so the figure is not left open in pyplot.
        plt.close(fig)


def main(*, epochs: int = 200, lr: float = 0.05, seed: int = 7) -> None:
    """Train the toy classifier, report its accuracy, and save a loss-curve plot.

    Args:
        epochs: number of full-batch training steps (default 200).
        lr: Adam learning rate (default 0.05).
        seed: seeds both ``torch.manual_seed`` (weight init) and
            ``make_dataset`` (data sampling); default 7.

    Side effects:
        Prints training progress and the final accuracy, and writes
        ``week8_loss_curve.png`` next to this file.
    """
    torch.manual_seed(seed)
    device = torch.device("cpu")
    features, labels = make_dataset(seed=seed)
    features, labels = features.to(device), labels.to(device)

    model = build_model().to(device)
    losses = _train(model, features, labels, epochs=epochs, lr=lr)
    # There is no held-out split, so this is accuracy on the training points.
    accuracy = _accuracy(model, features, labels)

    output_path = Path(__file__).with_name("week8_loss_curve.png")
    _save_loss_curve(losses, output_path)

    print(f"\nFinal accuracy: {accuracy * 100:.2f}%")
    print(f"Saved loss curve to: {output_path}")


if __name__ == "__main__":
    main()