-
Notifications
You must be signed in to change notification settings - Fork 16
/
modeling.go
22 lines (20 loc) · 943 Bytes
/
modeling.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
package transformer
import (
	"errors"

	"github.com/sugarme/gotch"
	"github.com/sugarme/transformer/pretrained"
)
// LoadModel loads pretrained model data from a local or remote source into model.
//
// It delegates to model.Load and returns that call's error unchanged. A nil
// model is rejected with an error rather than panicking on the method call.
// Note that a typed nil pointer stored in the interface is not nil to Go and
// is therefore not caught by this check.
//
// Parameters:
//   - model: a pretrained Model (any type implementing the pretrained Model interface).
//   - modelNameOrPath: a model name, a file name or path, or a URL to a remote file.
//   - config: a pretrained Config, passed through to model.Load unchanged.
//   - customParams: extra parameters, passed through to model.Load unchanged.
//   - device: the gotch device, passed through to model.Load unchanged.
//
// Resolution, caching, downloading and weight loading are all performed by the
// model's Load implementation, not by this function. Per this package's
// documentation, once modelNameOrPath is resolved the data is cached using the
// TransformerCache environment variable if set, and otherwise under
// $HOME/.cache/transformers/. If modelNameOrPath is a valid URL, the file is
// downloaded and cached. Finally, the model weights are loaded into the
// model's varstore. Confirm these details against the concrete Model
// implementation in use.
func LoadModel(model pretrained.Model, modelNameOrPath string, config pretrained.Config, customParams map[string]interface{}, device gotch.Device) error {
	// Guard against a nil interface so callers get an error instead of a
	// nil-dereference panic from library code.
	if model == nil {
		return errors.New("transformer: LoadModel called with nil model")
	}
	return model.Load(modelNameOrPath, config, customParams, device)
}