"""
@author: cb
@contact: chenbo@bat100.net
@time: 2023/5/30 14:21
@filename: tokenization.py
@software: PyCharm
@description: Custom FSMT tokenizer overrides for en<->zh translation.
"""
import re
from transformers import FSMTTokenizer as fsmt
| |
|
| |
|
class FSMTTokenizer(fsmt):
    """FSMT tokenizer specialized for en<->zh translation.

    Differences from the stock ``transformers`` FSMTTokenizer:

    * Moses tokenization runs with ``escape=False`` so characters such as
      ``&`` are not turned into HTML entities.
    * ``moses_pipeline`` applies punctuation normalization only — the parent
      class would additionally anglicize Chinese punctuation.
    * A language-prefix token id is prepended to encoded sequences.
    * Detokenized text has the spaces around non-alphanumeric runs (CJK
      characters, punctuation, ...) removed, which suits the business use.
    """

    # Matches a run of characters that are not ASCII letters/digits (CJK
    # text, punctuation, ...) together with any surrounding whitespace.
    # convert_tokens_to_string() substitutes the captured run back in,
    # dropping the whitespace on BOTH sides — the previous lookahead-only
    # pattern ('\s*(?=[^a-zA-Z0-9 ]+)\s*', a non-raw string that also
    # triggered an invalid-escape warning) removed only the spaces *before*
    # such a run, contradicting its stated intent.
    space_re = re.compile(r'\s*([^a-zA-Z0-9\s]+)\s*')

    def moses_tokenize(self, text, lang):
        """Tokenize ``text`` with a per-language, cached Moses tokenizer.

        Unlike the parent implementation this passes ``escape=False`` so
        special characters are kept verbatim instead of HTML-escaped.

        :param text: input string to tokenize
        :param lang: Moses language code used to pick/cache the tokenizer
        :return: list of token strings
        """
        if lang not in self.cache_moses_tokenizer:
            # Lazily create and memoize one MosesTokenizer per language.
            self.cache_moses_tokenizer[lang] = self.sm.MosesTokenizer(lang=lang)
        return self.cache_moses_tokenizer[lang].tokenize(
            text, aggressive_dash_splits=True, return_str=False, escape=False
        )

    def _switch_to_input_mode(self):
        # Source side is English; 64812 is assumed to be the vocab id of the
        # 'en' prefix token — TODO confirm against the loaded vocabulary.
        self.lang_prefix, self.lang_prefix_id = 'en', 64812

    def _switch_to_target_mode(self):
        # Target side is Chinese; 64870 is assumed to be the vocab id of the
        # 'zh' prefix token — TODO confirm against the loaded vocabulary.
        self.lang_prefix, self.lang_prefix_id = 'zh', 64870

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A FAIRSEQ Transformer sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        sep = [self.sep_token_id]
        # Prepend the active language-prefix id (set by the mode switchers).
        token_ids_0 = [self.lang_prefix_id] + token_ids_0

        if token_ids_1 is None:
            return token_ids_0 + sep
        return token_ids_0 + sep + token_ids_1 + sep

    def moses_pipeline(self, text, lang):
        """Apply punctuation normalization ONLY.

        The parent pipeline would also anglicize Chinese punctuation, which
        must be preserved here.
        """
        return self.moses_punct_norm(text, lang)

    def _tokenize(self, text, lang="en", bypass_tokenizer=False):
        """Tokenize ``text`` into BPE sub-tokens.

        Overridden because the original FSMTTokenizer anglicizes Chinese
        punctuation.  Note: ``lang`` is accepted for interface compatibility
        only; the active ``self.lang_prefix`` (set by the mode switchers) is
        what is actually used.

        :param text: input string
        :param lang: ignored — kept for signature compatibility
        :param bypass_tokenizer: if True, skip Moses and split on whitespace
        :return: list of BPE token strings
        """
        if self.do_lower_case:
            text = text.lower()
        if bypass_tokenizer:
            text = text.split()
        else:
            text = self.moses_pipeline(text, lang=self.lang_prefix)
            text = self.moses_tokenize(text, lang=self.lang_prefix)

        split_tokens = []
        for token in text:
            if token:
                split_tokens.extend(self.bpe(token).split(" "))

        return split_tokens

    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
        """Select the language mode, then defer to the parent implementation.

        ``kwargs['src']`` (default True) picks the source (en) vs target (zh)
        mode before tokenization.

        :param text: input string
        :param is_split_into_words: forwarded to the parent implementation
            (the previous code hard-coded ``False`` here, silently ignoring
            the caller's argument — fixed)
        :return: whatever the parent returns: ``(text, kwargs)``
        """
        if kwargs.get('src', True):
            self._switch_to_input_mode()
        else:
            self._switch_to_target_mode()
        return super(FSMTTokenizer, self).prepare_for_tokenization(
            text, is_split_into_words=is_split_into_words, **kwargs
        )

    def convert_tokens_to_string(self, tokens):
        """Detokenize, then strip spaces around non-English-letter runs.

        Removing the whitespace on both sides of CJK characters and
        punctuation gives output better suited to the business use.

        :param tokens: list of token strings
        :return: detokenized string
        """
        text = super(FSMTTokenizer, self).convert_tokens_to_string(tokens)
        # Keep the non-alphanumeric run (group 1), drop surrounding spaces.
        return FSMTTokenizer.space_re.sub(r'\1', text)
| |
|
| |
|
if __name__ == '__main__':
    # Smoke test: load the tokenizer from the current directory and tokenize
    # sample sentences.  NOTE: ``tokenize`` expects a single string, not a
    # list of strings (the original passed a list, which would crash inside
    # the lowercasing / Moses pipeline) — tokenize each sentence separately.
    tokenizer = FSMTTokenizer.from_pretrained(r'./')
    for sentence in ('hello', 'hi'):
        print(tokenizer.tokenize(sentence))