| import torch |
| import transformers |
| import random |
|
|
| |
# --- Model / tokenizer setup and training hyperparameters ---------------------
config = transformers.AutoConfig.from_pretrained("bert-base-uncased")
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")

model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", config=config
)

num_epochs = 3
batch_size = 32
learning_rate = 2e-5  # conventional BERT fine-tuning learning rate

# transformers.AdamW is deprecated (and removed in recent transformers
# releases); torch.optim.AdamW is the drop-in replacement.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
|
|
| |
def collate_fn(data):
    """Collate a list of samples into a batch of model-ready tensors.

    A DataLoader ``collate_fn`` receives a *list* of dataset items, so each
    sample is indexed individually rather than treating the batch as a dict
    of lists (the original ``data["text"]`` would fail on a list input).
    Each text is tokenized exactly once and the encoding reused for both
    ``input_ids`` and ``attention_mask``.

    Args:
        data: list of dicts with ``"text"`` (str) and ``"label"`` keys.
            # assumes this sample schema — TODO confirm against the dataset

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    encodings = [
        tokenizer(sample["text"], padding="max_length", truncation=True)
        for sample in data
    ]
    input_ids = torch.tensor([enc["input_ids"] for enc in encodings])
    attention_mask = torch.tensor([enc["attention_mask"] for enc in encodings])
    labels = torch.tensor([sample["label"] for sample in data])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
|
| |
def split_data(data, validation_size=0.2):
    """Randomly partition ``data`` into (train, validation) lists.

    Args:
        data: sequence of samples.
        validation_size: fraction of samples held out for validation.

    Returns:
        ``(train_data, val_data)`` — together they contain every input
        sample exactly once, each split preserving the original order.
    """
    # Use a set: membership tests inside the loop are O(1) instead of
    # O(k) per item against the list random.sample returns.
    validation_indices = set(
        random.sample(range(len(data)), int(len(data) * validation_size))
    )
    train_data = []
    val_data = []
    for i, item in enumerate(data):
        if i in validation_indices:
            val_data.append(item)
        else:
            train_data.append(item)
    return train_data, val_data
|
|
| |
# NOTE(review): `train_data` is not defined anywhere visible in this file — it
# must be loaded before this point (e.g. a list of {"text": ..., "label": ...}
# dicts); confirm where the dataset is supposed to come from.
train_data, val_data = split_data(train_data, validation_size=0.2)

# Build the loaders once, outside the epoch loop: shuffle=True already
# reshuffles every epoch, so rebuilding the loader per epoch is wasted work.
# (transformers has no DataLoader — it lives in torch.utils.data.)
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)
val_loader = torch.utils.data.DataLoader(
    val_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
)

for epoch in range(num_epochs):
    # Training pass: accumulate the loss so the epoch summary can report it
    # (the original printed a never-assigned `train_loss`).
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation pass: evaluation mode, no gradients.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            outputs = model(**batch)
            val_loss += outputs.loss.item()

    print("Epoch {}: Train Loss: {:.4f} Val Loss: {:.4f}".format(
        epoch + 1, train_loss / len(train_loader), val_loss / len(val_loader)
    ))

model.save_pretrained("finetuned_model")
|
|