-
Notifications
You must be signed in to change notification settings - Fork 6
/
hubconf.py
124 lines (115 loc) · 4.74 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Third-party packages required to load the entrypoints in this file
# (the `dependencies` list of the torch.hub hubconf.py convention).
dependencies = ["torch"]
from torch import hub
from channelvit.backbone.channel_vit import channelvit_small
from channelvit.backbone.hcs_channel_vit import hcs_channelvit_small
def imagenet_channelvit_small_p16_DINO(pretrained=True, *args, **kwargs):
    """Pretrained ChannelViT-Small model (patch size = 16) trained on ImageNet using DINO.

    Args:
        pretrained: If True, download the released checkpoint and load it
            into the model.
        *args, **kwargs: Forwarded unchanged to ``channelvit_small``. Do not
            pass ``patch_size`` or ``in_chans``: they are fixed to match the
            released checkpoint, and passing them again raises ``TypeError``.

    Returns:
        The ChannelViT-Small model, in evaluation mode.
    """
    # Star-args are written before the fixed keywords. Python binds them
    # positionally either way, but this order makes that binding explicit.
    model = channelvit_small(*args, patch_size=16, in_chans=3, **kwargs)
    if pretrained:
        # map_location="cpu" lets the checkpoint load on hosts without a GPU,
        # even if it was serialized from CUDA tensors. load_state_dict then
        # copies the values into the model's existing parameters.
        state_dict = hub.load_state_dict_from_url(
            "https://github.com/insitro/ChannelViT/releases/download/v1.0.0/imagenet_channelvit_small_p16_DINO.pth",
            map_location="cpu",
            progress=True,
        )
        model.load_state_dict(state_dict)
    # Return the model in evaluation mode, ready for inference.
    model.eval()
    return model
def imagenet_channelvit_small_p16_with_hcs_supervised(pretrained=True, *args, **kwargs):
    """Pretrained Supervised ChannelViT-Small model (patch size = 16) trained on ImageNet.

    Args:
        pretrained: If True, download the released checkpoint and load it
            into the model.
        *args, **kwargs: Forwarded unchanged to ``hcs_channelvit_small``. Do
            not pass ``patch_size`` or ``in_chans``: they are fixed to match
            the released checkpoint, and passing them again raises
            ``TypeError``.

    Returns:
        The HCS ChannelViT-Small model, in evaluation mode.
    """
    # Star-args are written before the fixed keywords. Python binds them
    # positionally either way, but this order makes that binding explicit.
    model = hcs_channelvit_small(*args, patch_size=16, in_chans=3, **kwargs)
    if pretrained:
        # map_location="cpu" lets the checkpoint load on hosts without a GPU,
        # even if it was serialized from CUDA tensors. load_state_dict then
        # copies the values into the model's existing parameters.
        state_dict = hub.load_state_dict_from_url(
            "https://github.com/insitro/ChannelViT/releases/download/v1.0.0/imagenet_channelvit_small_p16_with_hcs_supervised.pth",
            map_location="cpu",
            progress=True,
        )
        model.load_state_dict(state_dict)
    # Return the model in evaluation mode, ready for inference.
    model.eval()
    return model
def cpjump_cellpaint_channelvit_small_p8_with_hcs_supervised(pretrained=True, *args, **kwargs):
    """Pretrained Supervised ChannelViT-Small model (patch size = 8).

    Trained on the CellPainting channels from JUMP-CP (subset); the model is
    built with 5 input channels.

    Args:
        pretrained: If True, download the released checkpoint and load it
            into the model.
        *args, **kwargs: Forwarded unchanged to ``hcs_channelvit_small``. Do
            not pass ``patch_size`` or ``in_chans``: they are fixed to match
            the released checkpoint, and passing them again raises
            ``TypeError``.

    Returns:
        The HCS ChannelViT-Small model, in evaluation mode.
    """
    # Star-args are written before the fixed keywords. Python binds them
    # positionally either way, but this order makes that binding explicit.
    model = hcs_channelvit_small(*args, patch_size=8, in_chans=5, **kwargs)
    if pretrained:
        # map_location="cpu" lets the checkpoint load on hosts without a GPU,
        # even if it was serialized from CUDA tensors. load_state_dict then
        # copies the values into the model's existing parameters.
        state_dict = hub.load_state_dict_from_url(
            "https://github.com/insitro/ChannelViT/releases/download/v1.0.0/cpjump_cellpaint_channelvit_small_p8_with_hcs_supervised.pth",
            map_location="cpu",
            progress=True,
        )
        model.load_state_dict(state_dict)
    # Return the model in evaluation mode, ready for inference.
    model.eval()
    return model
def cpjump_cellpaint_bf_channelvit_small_p8_with_hcs_supervised(pretrained=True, *args, **kwargs):
    """Pretrained Supervised ChannelViT-Small model (patch size = 8).

    Trained on the CellPainting + Brightfield channels from JUMP-CP (subset);
    the model is built with 8 input channels.

    Args:
        pretrained: If True, download the released checkpoint and load it
            into the model.
        *args, **kwargs: Forwarded unchanged to ``hcs_channelvit_small``. Do
            not pass ``patch_size`` or ``in_chans``: they are fixed to match
            the released checkpoint, and passing them again raises
            ``TypeError``.

    Returns:
        The HCS ChannelViT-Small model, in evaluation mode.
    """
    # Star-args are written before the fixed keywords. Python binds them
    # positionally either way, but this order makes that binding explicit.
    model = hcs_channelvit_small(*args, patch_size=8, in_chans=8, **kwargs)
    if pretrained:
        # map_location="cpu" lets the checkpoint load on hosts without a GPU,
        # even if it was serialized from CUDA tensors. load_state_dict then
        # copies the values into the model's existing parameters.
        state_dict = hub.load_state_dict_from_url(
            "https://github.com/insitro/ChannelViT/releases/download/v1.0.0/cpjump_cellpaint_bf_channelvit_small_p8_with_hcs_supervised.pth",
            map_location="cpu",
            progress=True,
        )
        model.load_state_dict(state_dict)
    # Return the model in evaluation mode, ready for inference.
    model.eval()
    return model
def so2sat_channelvit_small_p8_with_hcs_random_split_supervised(pretrained=True, *args, **kwargs):
    """Pretrained Supervised ChannelViT-Small model (patch size = 8).

    Trained on all channels from the So2Sat dataset (random split); the model
    is built with 18 input channels.

    Args:
        pretrained: If True, download the released checkpoint and load it
            into the model.
        *args, **kwargs: Forwarded unchanged to ``hcs_channelvit_small``. Do
            not pass ``patch_size`` or ``in_chans``: they are fixed to match
            the released checkpoint, and passing them again raises
            ``TypeError``.

    Returns:
        The HCS ChannelViT-Small model, in evaluation mode.
    """
    # Star-args are written before the fixed keywords. Python binds them
    # positionally either way, but this order makes that binding explicit.
    model = hcs_channelvit_small(*args, patch_size=8, in_chans=18, **kwargs)
    if pretrained:
        # map_location="cpu" lets the checkpoint load on hosts without a GPU,
        # even if it was serialized from CUDA tensors. load_state_dict then
        # copies the values into the model's existing parameters.
        state_dict = hub.load_state_dict_from_url(
            "https://github.com/insitro/ChannelViT/releases/download/v1.0.0/so2sat_channelvit_small_p8_with_hcs_random_split_supervised.pth",
            map_location="cpu",
            progress=True,
        )
        model.load_state_dict(state_dict)
    # Return the model in evaluation mode, ready for inference.
    model.eval()
    return model
def so2sat_channelvit_small_p8_with_hcs_hard_split_supervised(pretrained=True, *args, **kwargs):
    """Pretrained Supervised ChannelViT-Small model (patch size = 8).

    Trained on all channels from the So2Sat dataset (hard split); the model
    is built with 18 input channels.

    Args:
        pretrained: If True, download the released checkpoint and load it
            into the model.
        *args, **kwargs: Forwarded unchanged to ``hcs_channelvit_small``. Do
            not pass ``patch_size`` or ``in_chans``: they are fixed to match
            the released checkpoint, and passing them again raises
            ``TypeError``.

    Returns:
        The HCS ChannelViT-Small model, in evaluation mode.
    """
    # Star-args are written before the fixed keywords. Python binds them
    # positionally either way, but this order makes that binding explicit.
    model = hcs_channelvit_small(*args, patch_size=8, in_chans=18, **kwargs)
    if pretrained:
        # map_location="cpu" lets the checkpoint load on hosts without a GPU,
        # even if it was serialized from CUDA tensors. load_state_dict then
        # copies the values into the model's existing parameters.
        state_dict = hub.load_state_dict_from_url(
            "https://github.com/insitro/ChannelViT/releases/download/v1.0.0/so2sat_channelvit_small_p8_with_hcs_hard_split_supervised.pth",
            map_location="cpu",
            progress=True,
        )
        model.load_state_dict(state_dict)
    # Return the model in evaluation mode, ready for inference.
    model.eval()
    return model
def camelyon_channelvit_small_p8_with_hcs_supervised(pretrained=True, *args, **kwargs):
    """Pretrained Supervised ChannelViT-Small model (patch size = 8).

    Trained on all channels from the WILDS Camelyon17 dataset; the model is
    built with 3 input channels.

    Args:
        pretrained: If True, download the released checkpoint and load it
            into the model.
        *args, **kwargs: Forwarded unchanged to ``hcs_channelvit_small``. Do
            not pass ``patch_size`` or ``in_chans``: they are fixed to match
            the released checkpoint, and passing them again raises
            ``TypeError``.

    Returns:
        The HCS ChannelViT-Small model, in evaluation mode.
    """
    # Star-args are written before the fixed keywords. Python binds them
    # positionally either way, but this order makes that binding explicit.
    model = hcs_channelvit_small(*args, patch_size=8, in_chans=3, **kwargs)
    if pretrained:
        # map_location="cpu" lets the checkpoint load on hosts without a GPU,
        # even if it was serialized from CUDA tensors. load_state_dict then
        # copies the values into the model's existing parameters.
        state_dict = hub.load_state_dict_from_url(
            "https://github.com/insitro/ChannelViT/releases/download/v1.0.0/camelyon_channelvit_small_p8_with_hcs_supervised.pth",
            map_location="cpu",
            progress=True,
        )
        model.load_state_dict(state_dict)
    # Return the model in evaluation mode, ready for inference.
    model.eval()
    return model