-
Notifications
You must be signed in to change notification settings - Fork 129
/
hubconf.py
66 lines (51 loc) · 2.41 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from strhub.models.utils import create_model
# Packages required by the torch.hub entry points defined in this file
# (torch.hub reads the module-level `dependencies` list by this exact name).
dependencies = [
    'torch',
    'pytorch_lightning',
    'timm',
]
def parseq_tiny(pretrained: bool = False, decode_ar: bool = True, refine_iters: int = 1, **kwargs):
    """
    PARSeq tiny model (img_size=128x32, patch_size=8x4, d_model=192)
    @param pretrained: (bool) whether to load pretrained weights
    @param decode_ar: (bool) whether to use AR decoding
    @param refine_iters: (int) how many refinement iterations to run
    @param kwargs: any extra keyword arguments, passed through unchanged to create_model
    """
    experiment = 'parseq-tiny'
    # decode_ar/refine_iters are named parameters, so they can never collide with keys in kwargs.
    model_kwargs = dict(decode_ar=decode_ar, refine_iters=refine_iters, **kwargs)
    return create_model(experiment, pretrained, **model_kwargs)
def parseq(
    pretrained: bool = False,
    decode_ar: bool = True,
    refine_iters: int = 1,
    **kwargs,
):
    """
    PARSeq base model (img_size=128x32, patch_size=8x4, d_model=384)
    @param pretrained: (bool) whether to load pretrained weights
    @param decode_ar: (bool) whether to use AR decoding
    @param refine_iters: (int) how many refinement iterations to run
    @param kwargs: any extra keyword arguments, passed through unchanged to create_model
    """
    options = {'decode_ar': decode_ar, 'refine_iters': refine_iters}
    # kwargs cannot contain the two keys above (they bind to named params), so nothing is overwritten.
    options.update(kwargs)
    return create_model('parseq', pretrained, **options)
def parseq_patch16_224(
    pretrained: bool = False,
    decode_ar: bool = True,
    refine_iters: int = 1,
    **kwargs,
):
    """
    PARSeq base model (img_size=224x224, patch_size=16x16, d_model=384)
    @param pretrained: (bool) whether to load pretrained weights
    @param decode_ar: (bool) whether to use AR decoding
    @param refine_iters: (int) how many refinement iterations to run
    @param kwargs: any extra keyword arguments, passed through unchanged to create_model
    """
    return create_model(
        'parseq-patch16-224',
        pretrained,
        decode_ar=decode_ar,
        refine_iters=refine_iters,
        **kwargs
    )
def abinet(pretrained: bool = False, iter_size: int = 3, **kwargs):
    """
    ABINet model (img_size=128x32)
    @param pretrained: (bool) whether to load pretrained weights
    @param iter_size: (int) how many refinement iterations to run
    @param kwargs: any extra keyword arguments, passed through unchanged to create_model
    """
    # iter_size goes first, followed by the caller's extra options in their original order.
    merged = {'iter_size': iter_size, **kwargs}
    return create_model('abinet', pretrained, **merged)
def trba(pretrained: bool = False, **kwargs):
    """
    TRBA model (img_size=128x32)
    @param pretrained: (bool) whether to load pretrained weights
    @param kwargs: any extra keyword arguments, passed through unchanged to create_model
    """
    model = create_model('trba', pretrained, **kwargs)
    return model
def vitstr(pretrained: bool = False, **kwargs):
    """
    ViTSTR small model (img_size=128x32, patch_size=8x4, d_model=384)
    @param pretrained: (bool) whether to load pretrained weights
    @param kwargs: any extra keyword arguments, passed through unchanged to create_model
    """
    experiment = 'vitstr'
    return create_model(experiment, pretrained, **kwargs)
def crnn(pretrained: bool = False, **kwargs):
    """
    CRNN model (img_size=128x32)
    @param pretrained: (bool) whether to load pretrained weights
    @param kwargs: any extra keyword arguments, passed through unchanged to create_model
    """
    experiment = 'crnn'
    built = create_model(experiment, pretrained, **kwargs)
    return built