Skip to content

Commit

Permalink
Make Flair conversion work with more models
Browse files Browse the repository at this point in the history
found on HuggingFace hub
  • Loading branch information
marco-nicola committed Sep 18, 2022
1 parent 64b4552 commit 4da856d
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 28 deletions.
35 changes: 35 additions & 0 deletions pkg/converter/flair/conversion/builtins/getattr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package builtins

import (
"fmt"

"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion"
)

// Getattr stands in for Python's built-in getattr function. It carries no
// state; the emulated behavior lives in its Call method.
type Getattr struct{}

// Call emulates Python's built-in getattr(object, name[, default]).
//
// args must hold two or three values:
//   - args[0] is the object; it must implement conversion.PyAttributeGettable;
//   - args[1] is the attribute name; it must be a string;
//   - args[2], when present, is the default returned for a missing attribute.
//
// The attribute is looked up through PyGetAttribute, and a lookup error is
// wrapped and returned. When the attribute exists its value is returned. When
// it does not exist, the default is returned if one was supplied; otherwise an
// error is returned, mirroring the AttributeError Python raises in that case
// (previously a missing attribute without a default was silently returned as
// a value with a nil error).
func (Getattr) Call(args ...any) (any, error) {
	if len(args) != 2 && len(args) != 3 {
		return nil, fmt.Errorf("builtins.getattr: want 2 or 3 args, got %d: %#v", len(args), args)
	}

	object, ok := args[0].(conversion.PyAttributeGettable)
	if !ok {
		return nil, fmt.Errorf("builtins.getattr: 1st arg (object) does not satisfy PyAttributeGettable interface: %T", args[0])
	}

	name, ok := args[1].(string)
	if !ok {
		return nil, fmt.Errorf("builtins.getattr: want 2nd arg (name) to be string, got %T: %#v", args[1], args[1])
	}

	value, exists, err := object.PyGetAttribute(name)
	if err != nil {
		return nil, fmt.Errorf("builtins.getattr(%#v): PyGetAttribute failed: %w", args, err)
	}

	if exists {
		return value, nil
	}
	if len(args) == 3 {
		// Attribute is missing but the caller supplied a default.
		return args[2], nil
	}
	// Attribute is missing and there is no default to fall back on.
	return nil, fmt.Errorf("builtins.getattr: object of type %T has no attribute %q", args[0], name)
}
3 changes: 3 additions & 0 deletions pkg/converter/flair/conversion/builtins/int.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package builtins

// Int is a stateless placeholder type for Python's builtins.int.
// NOTE(review): no methods are defined on it in this file — confirm how the
// unpickler is expected to use it.
type Int struct{}
25 changes: 25 additions & 0 deletions pkg/converter/flair/conversion/collections/defaultdict.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package collections

import (
"fmt"

"github.com/nlpodyssey/gopickle/types"
)

// DefaultDictClass stands in for the Python class collections.defaultdict.
// It carries no state; its Call method constructs DefaultDict values.
type DefaultDictClass struct{}

// DefaultDict is the Go-side representation of a Python
// collections.defaultdict instance.
type DefaultDict struct {
// Embedded dictionary; its methods are promoted onto DefaultDict.
*types.Dict
// DefaultFactory is the constructor argument given to DefaultDictClass.Call.
// It is stored as-is and never invoked in this file.
DefaultFactory any
}

// Call emulates the Python constructor collections.defaultdict([default_factory]).
//
// With no arguments the returned *DefaultDict has a nil DefaultFactory, which
// corresponds to Python's default_factory=None (CPython pickles such a
// defaultdict with an empty constructor-argument tuple). With one argument,
// that value is stored unchanged as DefaultFactory. More than one argument is
// rejected with an error. The returned dictionary always starts out empty.
func (DefaultDictClass) Call(args ...any) (any, error) {
	var factory any
	switch len(args) {
	case 0:
		// defaultdict() without a factory: leave DefaultFactory nil.
	case 1:
		factory = args[0]
	default:
		return nil, fmt.Errorf("DefaultDictClass: want 0 or 1 arguments, got %d: %#v", len(args), args)
	}

	return &DefaultDict{
		Dict:           types.NewDict(),
		DefaultFactory: factory,
	}, nil
}
10 changes: 7 additions & 3 deletions pkg/converter/flair/conversion/flair/dictionary.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@ import (
"fmt"

"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion"
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/collections"
"github.com/nlpodyssey/gopickle/types"
)

// DictionaryClass stands in for the Python class flair.data.Dictionary during
// unpickling; it carries no state of its own.
type DictionaryClass struct{}

type Dictionary struct {
Item2Idx map[string]int
Idx2Item []string
MultiLabel bool
Item2Idx map[string]int
Idx2Item []string
Item2IdxNotEncoded *collections.DefaultDict
MultiLabel bool
}

func (DictionaryClass) PyNew(args ...any) (any, error) {
Expand All @@ -40,6 +42,8 @@ func (d *Dictionary) PyDictSet(k, v any) (err error) {
if err == nil {
err = d.setIdx2Item(l)
}
case "item2idx_not_encoded":
err = conversion.AssignAssertedType(v, &d.Item2IdxNotEncoded)
case "multi_label":
err = conversion.AssignAssertedType(v, &d.MultiLabel)
default:
Expand Down
9 changes: 9 additions & 0 deletions pkg/converter/flair/conversion/flair/flairembeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ type FlairEmbeddings struct {
CharsPerChunk int
embeddingLength int
PretrainedModelArchiveMap map[string]string
InstanceParameters *types.Dict
WithWhitespace bool // Default: true
TokenizedLM bool // Default: true
LM *LanguageModel
}

Expand Down Expand Up @@ -66,6 +69,12 @@ func (f *FlairEmbeddings) PyDictSet(k, v any) (err error) {
if err == nil {
err = conversion.AssignDictToMap(d, &f.PretrainedModelArchiveMap)
}
case "instance_parameters":
err = conversion.AssignAssertedType(v, &f.InstanceParameters)
case "with_whitespace":
err = conversion.AssignAssertedType(v, &f.WithWhitespace)
case "tokenized_lm":
err = conversion.AssignAssertedType(v, &f.TokenizedLM)
case "detach", "cache": // TODO
default:
err = fmt.Errorf("unexpected key with value %#v", v)
Expand Down
17 changes: 10 additions & 7 deletions pkg/converter/flair/conversion/flair/languagemodel.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@ type LanguageModelClass struct{}

type LanguageModel struct {
torch.Module
Dictionary *Dictionary
IsForwardLm bool
Dropout float64
HiddenSize int
EmbeddingSize int
NLayers int
NOut int
Dictionary *Dictionary
IsForwardLm bool
Dropout float64
HiddenSize int
EmbeddingSize int
NLayers int
NOut int
DocumentDelimiter string

Encoder *torch.SparseEmbedding
Decoder *torch.Linear
Expand Down Expand Up @@ -69,6 +70,8 @@ func (l *LanguageModel) PyDictSet(k, v any) (err error) {
if v != nil {
err = fmt.Errorf("only nil is supported, got %T: %#v", v, v)
}
case "document_delimiter":
err = conversion.AssignAssertedType(v, &l.DocumentDelimiter)
default:
err = fmt.Errorf("unexpected key with value %#v", v)
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/converter/flair/conversion/flair/unpickling.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,24 @@ import (
"fmt"
"io"

"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/builtins"
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/collections"
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/gensim"
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/numpy"
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/torch"
"github.com/nlpodyssey/gopickle/pickle"
)

var allClasses = map[string]any{
"builtins.getattr": builtins.Getattr{},
"builtins.int": builtins.Int{},
"collections.defaultdict": collections.DefaultDictClass{},
"flair.data.Dictionary": DictionaryClass{},
"flair.embeddings.FlairEmbeddings": FlairEmbeddingsClass{},
"flair.embeddings.StackedEmbeddings": StackedEmbeddingsClass{},
"flair.embeddings.token.StackedEmbeddings": StackedEmbeddingsClass{},
"flair.embeddings.WordEmbeddings": WordEmbeddingsClass{},
"flair.embeddings.token.FlairEmbeddings": FlairEmbeddingsClass{},
"flair.embeddings.token.StackedEmbeddings": StackedEmbeddingsClass{},
"flair.embeddings.token.WordEmbeddings": WordEmbeddingsClass{},
"flair.models.language_model.LanguageModel": LanguageModelClass{},
"gensim.models.keyedvectors.Vocab": gensim.VocabClass{},
Expand All @@ -28,11 +34,11 @@ var allClasses = map[string]any{
"numpy.dtype": numpy.DTypeClass{},
"numpy.ndarray": numpy.NDArrayClass{},
"torch._utils._rebuild_parameter": torch.RebuildParameter{},
"torch.backends.cudnn.rnn.Unserializable": torch.RNNUnserializableClass{},
"torch.nn.modules.dropout.Dropout": torch.DropoutClass{},
"torch.nn.modules.linear.Linear": torch.LinearClass{},
"torch.nn.modules.rnn.LSTM": torch.LSTMClass{},
"torch.nn.modules.sparse.Embedding": torch.SparseEmbeddingClass{},
"torch.backends.cudnn.rnn.Unserializable": torch.RNNUnserializableClass{},
}

func newUnpickler(r io.Reader) pickle.Unpickler {
Expand Down
29 changes: 23 additions & 6 deletions pkg/converter/flair/conversion/flair/wordembeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ import (
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion"
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/gensim"
"github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/torch"
"github.com/nlpodyssey/gopickle/types"
"github.com/nlpodyssey/spago/mat"
)

// WordEmbeddingsClass stands in for the Python WordEmbeddings class from
// flair.embeddings during unpickling; it carries no state of its own.
type WordEmbeddingsClass struct{}

type WordEmbeddings struct {
TokenEmbeddingsModule
Embeddings string
Name string
StaticEmbeddings bool
Embedding *torch.Embedding
Vocab map[string]int
embeddingLength int
Embeddings string
Name string
StaticEmbeddings bool
Embedding *torch.Embedding
Vocab map[string]int
InstanceParameters *types.Dict
embeddingLength int
}

var _ TokenEmbeddings = &WordEmbeddings{}
Expand All @@ -48,6 +50,10 @@ func (w *WordEmbeddings) PyDictSet(k, v any) (err error) {
switch k {
case "embeddings":
err = conversion.AssignAssertedType(v, &w.Embeddings)
case "get_cached_vec":
// present on older models, can be ignored
case "instance_parameters":
err = conversion.AssignAssertedType(v, &w.InstanceParameters)
case "name":
err = conversion.AssignAssertedType(v, &w.Name)
case "static_embeddings":
Expand Down Expand Up @@ -93,6 +99,17 @@ func (w *WordEmbeddings) setPrecomputedWordEmbeddings(kv *gensim.KeyedVectors) e
return nil
}

// PyGetAttribute answers Python attribute lookups made on a WordEmbeddings
// object while it is being unpickled.
//
// Only "get_cached_vec" is recognized: it is reported as existing with a nil
// value. Any other name yields exists=false together with an error.
func (w *WordEmbeddings) PyGetAttribute(name string) (value any, exists bool, err error) {
	if name != "get_cached_vec" {
		return nil, false, fmt.Errorf("WordEmbeddings: unexpected __getattribute__(%q)", name)
	}
	// Older model versions reference the get_cached_vec method when loading.
	// It is deliberately ignored here, for compatibility, by reporting it as
	// present with a nil value.
	return nil, true, nil
}

// LoadStateDictEntry always fails: WordEmbeddings does not support being
// populated from state dict entries. Both arguments are ignored.
func (w *WordEmbeddings) LoadStateDictEntry(_ string, _ any) error {
	err := fmt.Errorf("WordEmbeddings: loading from state dict entry not implemented")
	return err
}
21 changes: 12 additions & 9 deletions pkg/converter/flair/conversion/torch/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ import (
)

type Module struct {
Training bool
Parameters *types.OrderedDict
Buffers *types.OrderedDict
BackwardHooks *types.OrderedDict
ForwardHooks *types.OrderedDict
ForwardPreHooks *types.OrderedDict
StateDictHooks *types.OrderedDict
LoadStateDictPreHooks *types.OrderedDict
Modules *types.OrderedDict
Training bool
Parameters *types.OrderedDict
Buffers *types.OrderedDict
BackwardHooks *types.OrderedDict
ForwardHooks *types.OrderedDict
ForwardPreHooks *types.OrderedDict
StateDictHooks *types.OrderedDict
LoadStateDictPreHooks *types.OrderedDict
Modules *types.OrderedDict
NonPersistentBuffersSet *types.Set
}

func GetSubModule[T any](mod Module, name string) (v T, err error) {
Expand Down Expand Up @@ -70,6 +71,8 @@ func (m *Module) PyDictSet(k, v any) (err error) {
err = conversion.AssignAssertedType(v, &m.ForwardPreHooks)
case "_state_dict_hooks":
err = conversion.AssignAssertedType(v, &m.StateDictHooks)
case "_non_persistent_buffers_set":
err = conversion.AssignAssertedType(v, &m.NonPersistentBuffersSet)
case "_load_state_dict_pre_hooks":
err = conversion.AssignAssertedType(v, &m.LoadStateDictPreHooks)
case "_modules":
Expand Down
13 changes: 13 additions & 0 deletions pkg/converter/flair/conversion/torch/rnn.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type RNNBase struct {
Bidirectional bool
ProjSize int
FlatWeightsNames []string
FlatWeights []*Parameter
AllWeights [][]string
Parameters map[string]*Parameter
}
Expand Down Expand Up @@ -202,6 +203,18 @@ func (r *RNNBase) PyDictSet(k, v any) (err error) {
err = conversion.AssignAssertedType(v, &r.Bidirectional)
case "_all_weights":
err = r.convertAndSetAllWeights(v)
case "_flat_weights":
var l *types.List
err = conversion.AssignAssertedType(v, &l)
if err == nil {
err = conversion.AssignListToSlice(l, &r.FlatWeights)
}
case "_flat_weights_names":
var l *types.List
err = conversion.AssignAssertedType(v, &l)
if err == nil {
err = conversion.AssignListToSlice(l, &r.FlatWeightsNames)
}
case "_data_ptrs", "_param_buf_size":
default:
err = fmt.Errorf("unexpected key with value %#v", v)
Expand Down
4 changes: 4 additions & 0 deletions pkg/converter/flair/conversion/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ import (
"github.com/nlpodyssey/spago/mat"
)

// PyAttributeGettable is implemented by converted objects that can answer
// Python attribute lookups.
//
// PyGetAttribute looks up the attribute called name and returns its value,
// whether the attribute exists, and any error encountered during the lookup.
type PyAttributeGettable interface {
PyGetAttribute(name string) (value any, exists bool, err error)
}

func AssertType[T any](v any) (vt T, err error) {
var ok bool
vt, ok = v.(T)
Expand Down
2 changes: 1 addition & 1 deletion pkg/converter/flair/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func (conv *converter[T]) encoderEmbeddingsTokensEncoderWordEmbeddings(st store.
func (conv *converter[T]) encoderEmbeddingsTokensEncoderCharLMVocabulary(forward, backward *convflair.FlairEmbeddings) (*vocabulary.Vocabulary, error) {
d1 := forward.LM.Dictionary
d2 := backward.LM.Dictionary
if !reflect.DeepEqual(d1, d2) {
if !reflect.DeepEqual(d1.Idx2Item, d2.Idx2Item) {
return nil, fmt.Errorf("FlairEmbeddings LM forward/backward dictionaries differ")
}
return vocabulary.New(d1.Idx2Item), nil
Expand Down

0 comments on commit 4da856d

Please sign in to comment.