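Many of us work on internal networks that cannot reach Hugging Face, and downloading models through ModelScope over an unstable connection fails easily. bert_score, however, only takes a model name and then either downloads that model from Hugging Face or loads it from a local cache. The cache it maintains for its SciBERT models (see cache_scibert below) is not the same as the official Hugging Face cache.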
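Solution 1: modify the bert_score source. bert_score only accepts a model name, but internally it loads the model with AutoTokenizer.from_pretrained and AutoModel.from_pretrained (in get_tokenizer and get_model in bert_score/utils.py). Most readers already know these two methods, and both accept a path to a local directory as well as a hub name. So all you need to do is add your own model_path in the source and pass model_path instead of model_type to those calls.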
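Everything depends on from_pretrained accepting a local directory. So before touching the source, it is worth confirming that the copied folder looks complete. The sketch below is a rough check of our own, not part of bert_score or transformers, and check_local_model_dir is a hypothetical name. It assumes the model was saved with save_pretrained, which writes config.json plus a weight file such as pytorch_model.bin or model.safetensors (or a sharded index). Tokenizer files are not checked.

```python
import os

# Weight file names a model saved with save_pretrained typically contains.
# Assumption: only the common single-file and sharded PyTorch/safetensors layouts are covered.
WEIGHT_FILES = (
    "pytorch_model.bin",
    "model.safetensors",
    "pytorch_model.bin.index.json",
    "model.safetensors.index.json",
)


def check_local_model_dir(model_path):
    """Return a list of problems found in model_path; an empty list means the basics are present."""
    if not os.path.isdir(model_path):
        return [f"not a directory: {model_path}"]
    problems = []
    # from_pretrained needs the model configuration
    if not os.path.isfile(os.path.join(model_path, "config.json")):
        problems.append("missing config.json")
    # and at least one weight file in a format it recognises
    if not any(os.path.isfile(os.path.join(model_path, name)) for name in WEIGHT_FILES):
        problems.append("no weight file found")
    return problems
```

For example, `print(check_local_model_dir("xxx"))` with your own path in place of xxx prints an empty list when config.json and a weight file are both present.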
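Original source (an excerpt from bert_score/utils.py):

```python
from packaging import version
from transformers import AutoModel, AutoTokenizer
from transformers import __version__ as trans_version


def get_model(model_type, num_layers, all_layers=None):
    if model_type.startswith("scibert"):
        # SciBERT names are resolved to a folder in bert_score's own cache
        model = AutoModel.from_pretrained(cache_scibert(model_type))
    elif "t5" in model_type:
        from transformers import T5EncoderModel

        model = T5EncoderModel.from_pretrained(model_type)
    else:
        # every other name goes straight to from_pretrained
        model = AutoModel.from_pretrained(model_type)
    model.eval()
    # ... remainder omitted; the function ends with `return model`


def get_tokenizer(model_type, use_fast=False):
    if model_type.startswith("scibert"):
        model_type = cache_scibert(model_type)
    if version.parse(trans_version) >= version.parse("4.0.0"):
        tokenizer = AutoTokenizer.from_pretrained(model_type, use_fast=use_fast)
    else:
        assert not use_fast, "Fast tokenizer is not available for version < 4.0.0"
        tokenizer = AutoTokenizer.from_pretrained(model_type)
    return tokenizer
```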
Change it to:

```python
def get_model(model_type, num_layers, all_layers=None):
    model_path = 'xxx'  # replace with the path to your local model directory
    if model_type.startswith("scibert"):
        model = AutoModel.from_pretrained(cache_scibert(model_type))
    elif "t5" in model_type:
        from transformers import T5EncoderModel

        model = T5EncoderModel.from_pretrained(model_path)  # was model_type
    else:
        model = AutoModel.from_pretrained(model_path)  # was model_type
    model.eval()
    # ... remainder of the function unchanged


def get_tokenizer(model_type, use_fast=False):
    if model_type.startswith("scibert"):
        model_type = cache_scibert(model_type)
    model_path = 'xxx'  # same local model directory as in get_model
    if version.parse(trans_version) >= version.parse("4.0.0"):
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)  # was model_type
    else:
        assert not use_fast, "Fast tokenizer is not available for version < 4.0.0"
        tokenizer = AutoTokenizer.from_pretrained(model_path)  # was model_type
    return tokenizer
```
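Hard-coding 'xxx' means editing the file again whenever the model moves. One possible variant, sketched here under our own assumptions, reads the path from an environment variable and falls back to model_type when the variable is unset. BERT_SCORE_MODEL_PATH and resolve_model_path are hypothetical names; bert_score itself does not know about them.

```python
import os

# Hypothetical environment variable; bert_score does not read it on its own.
ENV_VAR = "BERT_SCORE_MODEL_PATH"


def resolve_model_path(model_type):
    """Return the local path from ENV_VAR if it is set, otherwise model_type unchanged."""
    local_path = os.environ.get(ENV_VAR)
    if local_path:
        return os.path.abspath(os.path.expanduser(local_path))
    return model_type

# In the patched get_model / get_tokenizer one would then write, for example:
#     model = AutoModel.from_pretrained(resolve_model_path(model_type))
```

With the fallback, the patched library behaves exactly as before when the variable is not set. Either way, the edit lives in the installed package and is lost when bert_score is reinstalled or upgraded.

Also, when num_layers is not given, bert_score.score looks up a default layer for model_type in its model2layers table. So either keep model_type set to the name of the architecture your local copy corresponds to (for example bert-base-uncased), or pass num_layers explicitly.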
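Solution 2: let bert_score find a cached model. The code above already shows how bert_score loads a cached model. Only when model_type starts with scibert does it call cache_scibert, which maps the name to a folder in its own cache directory. So to load your model from that cache, the name passed as model_type must carry a scibert- prefix, and the local model must sit in the expected directory. The function follows its own naming rule for the folders it downloads into, so you have to name your model folder to match that rule.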
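The relevant part of the source:

```python
import os


def cache_scibert(model_type, cache_folder="~/.cache/torch/transformers"):
    # names that do not start with "scibert" are returned untouched
    if not model_type.startswith("scibert"):
        return model_type
    # the cache folder name is the model name with '-' replaced by '_'
    underscore_model_type = model_type.replace("-", "_")
    cache_folder = os.path.abspath(os.path.expanduser(cache_folder))
    filename = os.path.join(cache_folder, underscore_model_type)
    # ... download and file-fixing steps omitted; the function ends with `return filename`
```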
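The naming rule in cache_scibert can be reproduced to work out where a local model has to go. The sketch below is a helper of our own, and scibert_cache_dir and install_local_model are hypothetical names. It copies a model folder to cache_folder/<model_type with '-' replaced by '_'>. It mirrors only the path logic in the excerpt above, not the rest of cache_scibert.

```python
import os
import shutil

# Default cache folder of cache_scibert, taken from the excerpt above.
DEFAULT_CACHE = "~/.cache/torch/transformers"


def scibert_cache_dir(model_type, cache_folder=DEFAULT_CACHE):
    """Mirror cache_scibert's path logic: '-' becomes '_' under the absolute cache folder."""
    underscore_model_type = model_type.replace("-", "_")
    cache_folder = os.path.abspath(os.path.expanduser(cache_folder))
    return os.path.join(cache_folder, underscore_model_type)


def install_local_model(src_dir, model_type, cache_folder=DEFAULT_CACHE):
    """Copy a local model folder to the place cache_scibert will look for it."""
    # only names starting with "scibert" take the cache branch in get_model/get_tokenizer
    if not model_type.startswith("scibert"):
        raise ValueError("model_type must start with 'scibert' to use the cache branch")
    dest = scibert_cache_dir(model_type, cache_folder)
    # dirs_exist_ok requires Python 3.8+
    shutil.copytree(src_dir, dest, dirs_exist_ok=True)
    return dest
```

Take a hypothetical name such as scibert-my-model. The target folder becomes ~/.cache/torch/transformers/scibert_my_model, and you would then pass model_type="scibert-my-model" to bert_score. Such a name is not in bert_score's model2layers table, so num_layers has to be passed explicitly.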