import torch
|
|
|
class SimpleTokenizer:
    """Character-level tokenizer backed by a torch-serialized vocabulary.

    The vocabulary file is expected to deserialize to a dict mapping
    single characters to integer indices.  An ``'<unk>'`` entry is
    guaranteed to exist after construction so out-of-vocabulary
    characters always have a fallback index.
    """

    def __init__(self, vocab_path):
        """Load the char->index vocabulary from *vocab_path*.

        Args:
            vocab_path: Path to a torch-serialized ``dict[str, int]``.
        """
        # BUG FIX: original referenced the undefined name `vocab_pth`,
        # which raised NameError on every construction.
        self.char_to_idx = torch.load(vocab_path)

        # Ensure an <unk> fallback exists.  `default=-1` makes this safe
        # for an empty vocabulary too (first index becomes 0).
        if '<unk>' not in self.char_to_idx:
            self.char_to_idx['<unk>'] = max(self.char_to_idx.values(), default=-1) + 1

        # Inverse mapping used by decode().
        self.idx_to_char = {i: c for c, i in self.char_to_idx.items()}

    def encode(self, text):
        """Return the list of indices for each character of *text*.

        Unknown characters map to the ``'<unk>'`` index (or 0 if, on a
        hand-built instance, no ``'<unk>'`` entry exists).
        """
        # Hoist the invariant fallback lookup out of the per-char loop.
        unk = self.char_to_idx.get('<unk>', 0)
        return [self.char_to_idx.get(c, unk) for c in text]

    def decode(self, indices):
        """Return the string for *indices*, silently skipping unknown ones."""
        return ''.join(self.idx_to_char.get(i, '') for i in indices)
|
|
|
|
|
if __name__ == "__main__":
    # Demo: round-trip a sample string through the tokenizer.
    # Guarded so importing this module does not trigger file I/O or prints.
    vocab_path = 'vocab.pth'
    tokenizer = SimpleTokenizer(vocab_path)

    text = "Hello, world!"
    tokens = tokenizer.encode(text)
    print(tokens)

    decoded_text = tokenizer.decode(tokens)
    print(decoded_text)
|
|
|