model
Bases: Module, function
The base model class of the RPN model in the tinyBIG toolkit.
It inherits from the torch.nn.Module class and therefore also inherits its "state_dict" and "load_state_dict" methods.
...
Attributes:
Name | Type | Description |
---|---|---|
name | str, default = 'model_name' | Name of the model. |
Methods:
Name | Description |
---|---|
__init__ | It initializes the model. |
save_ckpt | It saves the model state to a checkpoint file. |
load_ckpt | It loads the model state from a checkpoint file. |
__call__ | It reimplements the built-in callable method. |
forward | The forward method of the model. |
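The contract described above (subclass the base class, implement `forward`, then call the instance) can be sketched in plain PyTorch. `BaseModelSketch` and `ToyModel` are hypothetical names, not tinyBIG's source; the sketch assumes the base class mixes in `abc.ABC` so that the abstract `forward` blocks direct instantiation, and that `device` is simply stored.

```python
import abc

import torch


class BaseModelSketch(torch.nn.Module, abc.ABC):
    """Hypothetical stand-in for tinybig's `model` base class (not its real source)."""

    def __init__(self, name: str = 'model_name', device: str = 'cpu', *args, **kwargs):
        super().__init__()
        self.name = name      # mirrors the documented `name` attribute
        self.device = device  # assumption: the device string is only stored here

    @abc.abstractmethod
    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Subclasses must implement the actual computation."""


class ToyModel(BaseModelSketch):
    """Illustrative subclass: a single linear layer."""

    def __init__(self, in_dim: int, out_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.linear = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# The base class itself cannot be instantiated because forward is abstract.
try:
    BaseModelSketch()
except TypeError as err:
    print('cannot instantiate:', err)

m = ToyModel(4, 2, name='toy')
y = m(torch.randn(3, 4))       # nn.Module.__call__ dispatches to forward
print(m.name, tuple(y.shape))  # toy (3, 2)
```

Whether tinyBIG enforces abstractness through `abc` is an assumption; without `abc.ABC`, an unimplemented `forward` would only fail when the model is first called.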
Source code in tinybig/module/base_model.py
__init__(name='model_name', device='cpu', *args, **kwargs)
The initialization method of the base model class.
It initializes a model object based on the provided model parameters.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model, with default value "model_name". | 'model_name' |
Returns:
Type | Description |
---|---|
object | The initialized model object. |
forward(*args, **kwargs)
abstractmethod
The forward method of the model.
It is declared as an abstractmethod and must be implemented by the inheriting RPN model classes. This callable method accepts data instances as input and generates the desired outputs.
Returns:
Type | Description |
---|---|
Tensor | The model-generated outputs. |
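A minimal sketch of a concrete `forward`, using a hypothetical `LinearRPNSketch` built on plain `torch.nn.Module` rather than tinyBIG's class. It shows why models are normally invoked as `m(x)` rather than `m.forward(x)`: in standard PyTorch, forward hooks run only through `__call__`. tinyBIG's own `__call__` reimplementation may add behavior beyond this.

```python
import torch


class LinearRPNSketch(torch.nn.Module):
    """Hypothetical concrete model: forward maps a batch of instances to outputs."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # x: (batch_size, in_dim) -> (batch_size, out_dim)
        return self.linear(x)


calls = []
m = LinearRPNSketch(5, 3)
# Forward hooks run only when the module is invoked through __call__.
m.register_forward_hook(lambda module, inputs, output: calls.append(output.shape))

x = torch.randn(8, 5)
out_call = m(x)             # goes through nn.Module.__call__, triggers the hook
out_direct = m.forward(x)   # bypasses __call__, so the hook does not run

print(len(calls), tuple(out_call.shape))  # 1 (8, 3)
print(torch.equal(out_call, out_direct))  # True
```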
load_ckpt(cache_dir='./ckpt', checkpoint_file='checkpoint', strict=True)
The model state checkpoint loading method.
It loads the model state from the provided checkpoint file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cache_dir | str | The cache directory of the model checkpoint file. | './ckpt' |
checkpoint_file | str | The checkpoint file name. | 'checkpoint' |
strict | bool | Whether loading the model state enforces strict checking of the checkpoint against the model configuration. | True |
Returns:
Type | Description |
---|---|
None | This method does not return a value. |
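The loading behavior can be sketched with standard PyTorch calls. `load_ckpt_sketch` is a hypothetical helper that only mirrors the documented parameters; it assumes the checkpoint lives at `cache_dir/checkpoint_file`, contains a plain `state_dict`, and that `strict` is forwarded to `torch.nn.Module.load_state_dict`. tinyBIG's actual file layout may differ.

```python
import os
import tempfile

import torch


def load_ckpt_sketch(module: torch.nn.Module, cache_dir: str = './ckpt',
                     checkpoint_file: str = 'checkpoint', strict: bool = True) -> None:
    """Hypothetical helper mirroring load_ckpt's parameters (not tinyBIG's source)."""
    # Assumption: the checkpoint path is cache_dir/checkpoint_file and holds a state_dict.
    path = os.path.join(cache_dir, checkpoint_file)
    state = torch.load(path, map_location='cpu')
    # strict=True requires the checkpoint keys to match the module's keys exactly.
    module.load_state_dict(state, strict=strict)


strict_failed = False
with tempfile.TemporaryDirectory() as tmp:
    src = torch.nn.Linear(4, 2)
    torch.save(src.state_dict(), os.path.join(tmp, 'checkpoint'))

    dst = torch.nn.Linear(4, 2)
    load_ckpt_sketch(dst, cache_dir=tmp)

    # The checkpoint has a 'bias' entry this module lacks: strict loading rejects it.
    no_bias = torch.nn.Linear(4, 2, bias=False)
    try:
        load_ckpt_sketch(no_bias, cache_dir=tmp)
    except RuntimeError:
        strict_failed = True
    load_ckpt_sketch(no_bias, cache_dir=tmp, strict=False)  # loads 'weight' only

print(strict_failed, torch.equal(src.weight, no_bias.weight))  # True True
```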
save_ckpt(cache_dir='./ckpt', checkpoint_file='checkpoint')
The model state checkpoint saving method.
It saves the current model state to a checkpoint file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cache_dir | str | The cache directory of the model checkpoint file. | './ckpt' |
checkpoint_file | str | The checkpoint file name. | 'checkpoint' |
Returns:
Type | Description |
---|---|
None | This method does not return a value. |
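A matching sketch of the saving side, again using a hypothetical `save_ckpt_sketch` rather than tinyBIG's code. It assumes the state is written with `torch.save(module.state_dict(), ...)` to `cache_dir/checkpoint_file` and that a missing cache directory is created first; both are assumptions about the real implementation.

```python
import os
import tempfile

import torch


def save_ckpt_sketch(module: torch.nn.Module, cache_dir: str = './ckpt',
                     checkpoint_file: str = 'checkpoint') -> None:
    """Hypothetical helper mirroring save_ckpt's parameters (not tinyBIG's source)."""
    os.makedirs(cache_dir, exist_ok=True)  # assumption: the directory is created if missing
    torch.save(module.state_dict(), os.path.join(cache_dir, checkpoint_file))


with tempfile.TemporaryDirectory() as tmp:
    cache_dir = os.path.join(tmp, 'ckpt')  # does not exist yet
    net = torch.nn.Linear(3, 1)
    save_ckpt_sketch(net, cache_dir=cache_dir)

    path = os.path.join(cache_dir, 'checkpoint')
    saved_exists = os.path.isfile(path)
    restored = torch.nn.Linear(3, 1)
    restored.load_state_dict(torch.load(path))  # round-trip the saved state

print(saved_exists, torch.equal(net.weight, restored.weight))  # True True
```

Saving the `state_dict` rather than the whole module object keeps the checkpoint independent of the class definition, which is what makes a separate `load_ckpt` step with `strict` checking meaningful.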