| | import torch
|
| | import torch.nn as nn
|
| | import os
|
| | from model import MiniText
|
| |
|
| |
|
| |
|
| |
|
# --- Training hyperparameters -------------------------------------------
SEQ_LEN = 64                      # context window length (in bytes) per training sample
EPOCHS = 12000                    # total number of single-batch training steps
LR = 1e-4                         # Adam learning rate
SAVE_EVERY = 2000                 # checkpoint every N epochs
CHECKPOINT_PATH = "checkpoint.pt" # where resume state (model/optimizer/epoch) is stored
|
| |
|
| |
|
| |
|
| |
|
| |
|
# Load the training corpus as raw bytes: each byte (0-255) becomes one token,
# which is why the loss below is computed over a 256-way vocabulary.
with open("dataset.txt", "rb") as f:
    data = torch.tensor(list(f.read()), dtype=torch.long)
|
| |
|
| |
|
| |
|
| |
|
# Model, optimizer and loss. MiniText is project-local; from the usage below
# it takes a (batch, seq) LongTensor and returns (logits, extra) where logits
# flatten to (batch*seq, 256) — TODO confirm against model.py.
model = MiniText()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()

# First epoch to run; overwritten below when a checkpoint is found.
start_epoch = 0
|
| |
|
| |
|
| |
|
| |
|
# Resume from a previous run if a checkpoint exists; otherwise start fresh.
# The checkpoint dict is written by the training loop below and holds the
# model weights, optimizer state, and the last completed epoch index.
if os.path.exists(CHECKPOINT_PATH):
    print("Checkpoint encontrado, retomando treino...")
    # map_location="cpu" makes resuming robust: without it, a checkpoint
    # saved on a GPU machine would fail to load on a CPU-only machine.
    # This script never moves the model off the CPU, so "cpu" is always safe.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    start_epoch = checkpoint["epoch"] + 1  # resume AFTER the last saved epoch
else:
    print("Nenhum checkpoint encontrado, treino do zero.")
|
| |
|
| |
|
| |
|
| |
|
def get_batch(batch_size=1):
    """Sample random contiguous (input, target) pairs from the corpus.

    Reads the module-level ``data`` (1-D LongTensor of byte tokens) and
    ``SEQ_LEN``. Targets are the inputs shifted one position to the right
    (next-byte prediction).

    Args:
        batch_size: number of independent windows to sample. Defaults to 1,
            matching the original single-sample behavior and output shapes.

    Returns:
        Tuple ``(x, y)`` of LongTensors, each of shape (batch_size, SEQ_LEN).
    """
    # Upper bound keeps the shifted target slice fully inside `data`.
    # Note: the original sliced with a shape-(1,) tensor, which only works
    # through torch's implicit __index__ on one-element tensors; using plain
    # Python ints is explicit and robust.
    max_start = len(data) - SEQ_LEN - 1
    starts = [int(i) for i in torch.randint(0, max_start, (batch_size,))]
    x = torch.stack([data[s:s + SEQ_LEN] for s in starts])
    y = torch.stack([data[s + 1:s + SEQ_LEN + 1] for s in starts])
    return x, y
|
| |
|
| |
|
| |
|
| |
|
# Main training loop: one randomly-sampled batch per epoch, with periodic
# checkpointing so an interrupted run can resume from `start_epoch`.
for epoch in range(start_epoch, EPOCHS):
    inputs, targets = get_batch()

    # Standard step: clear stale gradients, forward, loss, backward, update.
    optimizer.zero_grad()
    logits, _ = model(inputs)
    # Byte-level vocabulary: logits flattened to (batch*seq, 256) for CE loss.
    loss = loss_fn(logits.view(-1, 256), targets.view(-1))
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {loss.item():.4f}")

    # Periodically persist everything needed to resume training exactly here.
    if (epoch + 1) % SAVE_EVERY == 0:
        snapshot = {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(snapshot, CHECKPOINT_PATH)
        print("Checkpoint salvo.")
|
| |
|
| |
|
| |
|
| |
|
# After the full run, export only the model weights (no optimizer/epoch
# state) as the final artifact — distinct from the resumable checkpoint.
torch.save(model.state_dict(), "minitext.pt")
print("Treino finalizado. Modelo salvo em minitext.pt")
|
| |
|