1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
import open_clip
import torch
import transformers
from multilingual_clip import pt_multilingual_clip, Config_MCLIP
# Default device that CudaMultilingualCLIP.forward moves tokenized inputs to.
DEVICE: str = "cuda"
class CudaMultilingualCLIP(transformers.PreTrainedModel):
    """Multilingual CLIP text encoder that can run its forward pass on a GPU.

    The model is a pretrained transformer backbone (``config.modelBase``)
    followed by attention-mask-weighted mean pooling and a linear projection
    from ``config.transformerDimensions`` to ``config.numDims`` features.
    ``forward`` moves the tokenized batch to ``device`` before encoding, so
    the inputs can live on the same device as the model's parameters.
    """

    config_class = Config_MCLIP.MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        """Build the backbone and the projection head described by ``config``.

        Args:
            config: An ``MCLIPConfig`` providing ``modelBase`` (passed to
                ``transformers.AutoModel.from_pretrained``),
                ``transformerDimensions`` (projection input size) and
                ``numDims`` (projection output size).
            *args, **kwargs: Forwarded unchanged to ``PreTrainedModel``.
        """
        super().__init__(config, *args, **kwargs)
        self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
        # Keep this attribute name: renaming it would change the state-dict
        # keys that ``load_state_dict`` (strict by default) expects.
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions,
            out_features=config.numDims,
        )

    def forward(self, txt, tokenizer, device=DEVICE):
        """Encode a batch of texts into projected embeddings.

        Args:
            txt: Text (or list of texts) accepted by ``tokenizer``.
            tokenizer: Tokenizer callable returning a ``BatchEncoding`` with an
                ``attention_mask``; it must match the backbone.
            device: Device the tokenized tensors are moved to. It should match
                the device holding the model's parameters. Defaults to
                ``DEVICE``.

        Returns:
            The ``LinearTransformation`` output, presumably of shape
            ``(batch, config.numDims)``.
        """
        txt_tok = tokenizer(txt, padding=True, return_tensors='pt').to(device)
        # First backbone output; presumably (batch, seq_len, hidden) -- TODO confirm.
        embs = self.transformer(**txt_tok)[0]
        att = txt_tok['attention_mask']
        # Mean over real tokens only: zero out padded positions, then divide
        # by the number of unmasked positions in each row.
        embs = (embs * att.unsqueeze(2)).sum(dim=1) / att.sum(dim=1)[:, None]
        return self.LinearTransformation(embs)

    @classmethod
    def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
        """Load ``state_dict`` into ``model`` with a plain ``load_state_dict``.

        Overrides the ``PreTrainedModel`` hook of the same name (presumably
        invoked from ``from_pretrained``). ``pretrained_model_name_or_path``
        and ``_fast_init`` are accepted only for signature compatibility and
        are ignored.

        Returns:
            ``(model, [], [], [])``: the model plus three empty lists.
            NOTE(review): the number and meaning of the returned lists depend
            on the installed ``transformers`` version; confirm this matches
            what its loader unpacks.
        """
        model.load_state_dict(state_dict)
        return model, [], [], []
|