-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
64 lines (51 loc) · 2.13 KB
/
model.py
File metadata and controls
64 lines (51 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
# Default pretrained identifier for the text side. GelattoModel uses it as the
# default `tokenizer_name`, which is passed both to AutoTokenizer.from_pretrained
# (via _TextAndImageNamespace) and to AutoModel.from_pretrained.
TEXT_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
class ResNet18ImageEncoder(nn.Module):
    """Image encoder built from torchvision's ResNet-18 minus its last child module.

    All children of ``resnet18`` except the final one (the ``fc`` classifier in
    torchvision's implementation) are kept. The encoder therefore produces
    trunk features instead of class logits.

    Args:
        pretrained: If true, initialise from ``ResNet18_Weights.DEFAULT``;
            otherwise build the network with ``weights=None``.

    Attributes:
        backbone: ``nn.Sequential`` holding the retained child modules.
        enc_dim: Width of the original ``fc`` layer's input
            (``fc.in_features``), exposed for downstream projection heads.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        # Function-scope import: torchvision is only loaded when an encoder
        # is actually constructed.
        from torchvision.models import ResNet18_Weights, resnet18

        if pretrained:
            init_weights = ResNet18_Weights.DEFAULT
        else:
            init_weights = None
        net = resnet18(weights=init_weights)

        # Keep every child except the trailing classifier head.
        trunk_layers = list(net.children())
        self.backbone = nn.Sequential(*trunk_layers[:-1])
        self.enc_dim = net.fc.in_features

    def forward_features(self, imgs: torch.Tensor) -> torch.Tensor:
        """Run ``backbone`` and flatten every dimension after the batch axis.

        The result is 2-D; presumably ``(batch, enc_dim)`` for the standard
        ResNet-18 trunk (confirm for custom backbones).
        """
        trunk_out = self.backbone(imgs)
        return trunk_out.flatten(start_dim=1)

    def forward_encoder(self, imgs: torch.Tensor) -> torch.Tensor:
        """Return ``forward_features`` with a length-1 axis inserted at dim 1.

        A ``(B, D)`` feature matrix becomes ``(B, 1, D)``, i.e. a one-token
        sequence per image.
        """
        return self.forward_features(imgs)[:, None]

    def features(self, imgs: torch.Tensor) -> torch.Tensor:
        """Alias that simply delegates to :meth:`forward_features`."""
        return self.forward_features(imgs)
class _TextAndImageNamespace(nn.Module):
    """Groups an image encoder with a text tokenizer under one module.

    The image encoder is registered as the ``image`` submodule. The tokenizer
    is stored as a plain attribute (it is not an ``nn.Module``), so it adds no
    parameters.

    Args:
        image: Image encoder module, registered as ``self.image``.
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        max_len: Maximum sequence length used for truncation in :meth:`tokenize`.
    """

    def __init__(self, image: nn.Module, tokenizer_name: str, max_len: int):
        super().__init__()
        self.image = image
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_len = max_len

    def tokenize(self, texts, device):
        """Tokenize a batch of strings and move the encoding to ``device``.

        Args:
            texts: A single string, or any iterable of strings (list, tuple,
                generator, ...). A bare ``str`` is treated as a batch of one.
            device: Target passed to the encoding's ``.to(...)``.

        Returns:
            The tokenizer's batch encoding with PyTorch tensors, padded across
            the batch and truncated to ``self.max_len`` tokens.
        """
        # A str is itself iterable, so list("a cat") would yield one "text" per
        # character. Wrap it so a single caption stays a single sequence.
        if isinstance(texts, str):
            texts = [texts]
        return self.tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        ).to(device)
class GelattoModel(nn.Module):
    """Image/text dual encoder with linear projections into a shared space.

    Image tower: ``ResNet18ImageEncoder``. Text tower:
    ``AutoModel.from_pretrained(tokenizer_name)``. Each tower gets an
    ``nn.Linear`` head into ``embed_dim`` dimensions. A learnable scalar
    ``logit_scale`` is also created.

    Args:
        max_len: Tokenizer truncation length, forwarded to ``self.cap``.
        embed_dim: Output width of both projection heads.
        tokenizer_name: Identifier used for both the tokenizer and the text model.
        resnet_pretrained: Whether the ResNet-18 trunk loads pretrained weights.
        **_unused: Extra keyword arguments, accepted and ignored.
    """

    def __init__(
        self,
        max_len: int = 128,
        embed_dim: int = 256,
        tokenizer_name: str = TEXT_MODEL,
        resnet_pretrained: bool = True,
        **_unused,
    ):
        super().__init__()
        # NOTE: statement order below fixes submodule registration order and
        # the order in which random initialisers run; keep it stable.
        vision_encoder = ResNet18ImageEncoder(pretrained=resnet_pretrained)
        self.cap = _TextAndImageNamespace(
            image=vision_encoder,
            tokenizer_name=tokenizer_name,
            max_len=max_len,
        )
        # The same encoder object is registered a second time under a shorter
        # name. Its weights therefore appear under both `cap.image.*` and
        # `image.*` in state_dict; parameters() de-duplicates them.
        self.image = self.cap.image

        # The text tower is loaded from the same identifier as the tokenizer.
        self.text = AutoModel.from_pretrained(tokenizer_name)
        text_width = self.text.config.hidden_size

        # Per-tower linear heads into the shared embed_dim space.
        self.img_proj = nn.Linear(self.image.enc_dim, embed_dim)
        self.txt_proj = nn.Linear(text_width, embed_dim)

        # exp(2.6593) ~= 14.29 ~= 1 / 0.07. This is presumably a CLIP-style
        # temperature initialisation (confirm against the training code).
        self.logit_scale = nn.Parameter(torch.tensor(2.6593))