-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathget_tokenlizer.py
More file actions
29 lines (24 loc) · 1.31 KB
/
get_tokenlizer.py
File metadata and controls
29 lines (24 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast
import os
def get_tokenlizer(text_encoder_type):
    """Create a tokenizer via ``AutoTokenizer.from_pretrained``.

    Args:
        text_encoder_type: Identifies the text encoder. Accepted forms:

            * a ``str``, passed unchanged to ``AutoTokenizer.from_pretrained``;
            * an object with a ``text_encoder_type`` attribute, whose value
              is used instead;
            * an object with a ``get`` method (e.g. a ``dict``) for which
              ``get("text_encoder_type")`` is truthy, whose value is used;
            * an ``os.PathLike`` naming an existing directory, passed
              through unchanged.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns for the resolved
        value.

    Raises:
        ValueError: If a non-string argument matches none of the accepted
            forms above.
    """
    if not isinstance(text_encoder_type, str):
        if hasattr(text_encoder_type, "text_encoder_type"):
            text_encoder_type = text_encoder_type.text_encoder_type
        elif callable(
            getattr(text_encoder_type, "get", None)
        ) and text_encoder_type.get("text_encoder_type", False):
            # Only probe ``.get`` when it exists: objects without it (paths,
            # numbers, None, ...) must reach the branches below instead of
            # dying with AttributeError.
            text_encoder_type = text_encoder_type.get("text_encoder_type")
        elif isinstance(text_encoder_type, os.PathLike) and os.path.isdir(
            text_encoder_type
        ):
            # Path-like pointing at an existing directory: hand it to
            # from_pretrained as-is. The PathLike guard keeps os.path.isdir
            # from raising TypeError on arbitrary objects.
            pass
        else:
            raise ValueError(
                "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
            )
    print("final text_encoder_type: {}".format(text_encoder_type))

    return AutoTokenizer.from_pretrained(text_encoder_type)
def get_pretrained_language_model(text_encoder_type):
    """Load the pretrained language model matching ``text_encoder_type``.

    Args:
        text_encoder_type: ``"bert-base-uncased"``, ``"roberta-base"``, or
            the path of an existing directory.

    Returns:
        ``BertModel.from_pretrained(text_encoder_type)`` for
        ``"bert-base-uncased"`` or an existing directory, and
        ``RobertaModel.from_pretrained(text_encoder_type)`` for
        ``"roberta-base"``.

    Raises:
        ValueError: If ``text_encoder_type`` is none of the above.
    """
    # The directory probe runs only when the name is not the known BERT id,
    # and it takes precedence over the RoBERTa name check.
    if text_encoder_type == "bert-base-uncased":
        load_as_bert = True
    else:
        load_as_bert = os.path.isdir(text_encoder_type) and os.path.exists(
            text_encoder_type
        )

    if load_as_bert:
        model_cls = BertModel
    elif text_encoder_type == "roberta-base":
        model_cls = RobertaModel
    else:
        raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))

    return model_cls.from_pretrained(text_encoder_type)