diff --git a/pkg/converter/flair/conversion/builtins/getattr.go b/pkg/converter/flair/conversion/builtins/getattr.go new file mode 100644 index 0000000..f62e4b1 --- /dev/null +++ b/pkg/converter/flair/conversion/builtins/getattr.go @@ -0,0 +1,35 @@ +package builtins + +import ( + "fmt" + + "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion" +) + +type Getattr struct{} + +func (Getattr) Call(args ...any) (any, error) { + if len(args) != 2 && len(args) != 3 { + return nil, fmt.Errorf("builtins.getattr: want 2 or 3 args, got %d: %#v", len(args), args) + } + + object, ok := args[0].(conversion.PyAttributeGettable) + if !ok { + return nil, fmt.Errorf("builtins.getattr: 1st arg (object) does not satisfy PyAttributeGettable interface: %T", args[0]) + } + + name, ok := args[1].(string) + if !ok { + return nil, fmt.Errorf("builtins.getattr: want 2nd arg (name) to be string, got %T: %#v", args[1], args[1]) + } + + value, exists, err := object.PyGetAttribute(name) + if err != nil { + return nil, fmt.Errorf("builtins.getattr(%#v): PyGetAttribute failed: %w", args, err) + } + + if len(args) == 3 && !exists { + return args[2], nil + } + return value, nil +} diff --git a/pkg/converter/flair/conversion/builtins/int.go b/pkg/converter/flair/conversion/builtins/int.go new file mode 100644 index 0000000..fe0b091 --- /dev/null +++ b/pkg/converter/flair/conversion/builtins/int.go @@ -0,0 +1,3 @@ +package builtins + +type Int struct{} diff --git a/pkg/converter/flair/conversion/collections/defaultdict.go b/pkg/converter/flair/conversion/collections/defaultdict.go new file mode 100644 index 0000000..6e2934f --- /dev/null +++ b/pkg/converter/flair/conversion/collections/defaultdict.go @@ -0,0 +1,25 @@ +package collections + +import ( + "fmt" + + "github.com/nlpodyssey/gopickle/types" +) + +type DefaultDictClass struct{} + +type DefaultDict struct { + *types.Dict + DefaultFactory any +} + +func (DefaultDictClass) Call(args ...any) (any, error) { + if len(args) != 1 { + return nil, fmt.Errorf("DefaultDictClass: want 1 argument, got %d: %#v", len(args), args) + } + + return &DefaultDict{ + Dict: types.NewDict(), + DefaultFactory: args[0], + }, nil +} diff --git a/pkg/converter/flair/conversion/flair/dictionary.go b/pkg/converter/flair/conversion/flair/dictionary.go index ddee76d..acfaa16 100644 --- a/pkg/converter/flair/conversion/flair/dictionary.go +++ b/pkg/converter/flair/conversion/flair/dictionary.go @@ -8,15 +8,17 @@ import ( "fmt" "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion" + "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/collections" "github.com/nlpodyssey/gopickle/types" ) type DictionaryClass struct{} type Dictionary struct { - Item2Idx map[string]int - Idx2Item []string - MultiLabel bool + Item2Idx map[string]int + Idx2Item []string + Item2IdxNotEncoded *collections.DefaultDict + MultiLabel bool } func (DictionaryClass) PyNew(args ...any) (any, error) { @@ -40,6 +42,8 @@ func (d *Dictionary) PyDictSet(k, v any) (err error) { if err == nil { err = d.setIdx2Item(l) } + case "item2idx_not_encoded": + err = conversion.AssignAssertedType(v, &d.Item2IdxNotEncoded) case "multi_label": err = conversion.AssignAssertedType(v, &d.MultiLabel) default: diff --git a/pkg/converter/flair/conversion/flair/flairembeddings.go b/pkg/converter/flair/conversion/flair/flairembeddings.go index c392945..9a2a195 100644 --- a/pkg/converter/flair/conversion/flair/flairembeddings.go +++ b/pkg/converter/flair/conversion/flair/flairembeddings.go @@ -24,6 +24,9 @@ type FlairEmbeddings struct { CharsPerChunk int embeddingLength int PretrainedModelArchiveMap map[string]string + InstanceParameters *types.Dict + WithWhitespace bool // Default: true + TokenizedLM bool // Default: true LM *LanguageModel } @@ -66,6 +69,12 @@ func (f *FlairEmbeddings) PyDictSet(k, v any) (err error) { if err == nil { err = conversion.AssignDictToMap(d, &f.PretrainedModelArchiveMap) } + case "instance_parameters": + err = conversion.AssignAssertedType(v, &f.InstanceParameters) + case "with_whitespace": + err = conversion.AssignAssertedType(v, &f.WithWhitespace) + case "tokenized_lm": + err = conversion.AssignAssertedType(v, &f.TokenizedLM) case "detach", "cache": // TODO default: err = fmt.Errorf("unexpected key with value %#v", v) diff --git a/pkg/converter/flair/conversion/flair/languagemodel.go b/pkg/converter/flair/conversion/flair/languagemodel.go index 2b68cec..40b86cd 100644 --- a/pkg/converter/flair/conversion/flair/languagemodel.go +++ b/pkg/converter/flair/conversion/flair/languagemodel.go @@ -16,13 +16,14 @@ type LanguageModelClass struct{} type LanguageModel struct { torch.Module - Dictionary *Dictionary - IsForwardLm bool - Dropout float64 - HiddenSize int - EmbeddingSize int - NLayers int - NOut int + Dictionary *Dictionary + IsForwardLm bool + Dropout float64 + HiddenSize int + EmbeddingSize int + NLayers int + NOut int + DocumentDelimiter string Encoder *torch.SparseEmbedding Decoder *torch.Linear @@ -69,6 +70,8 @@ func (l *LanguageModel) PyDictSet(k, v any) (err error) { if v != nil { err = fmt.Errorf("only nil is supported, got %T: %#v", v, v) } + case "document_delimiter": + err = conversion.AssignAssertedType(v, &l.DocumentDelimiter) default: err = fmt.Errorf("unexpected key with value %#v", v) } diff --git a/pkg/converter/flair/conversion/flair/unpickling.go b/pkg/converter/flair/conversion/flair/unpickling.go index c0a07c1..6ad304a 100644 --- a/pkg/converter/flair/conversion/flair/unpickling.go +++ b/pkg/converter/flair/conversion/flair/unpickling.go @@ -8,6 +8,8 @@ import ( "fmt" "io" + "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/builtins" + "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/collections" "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/gensim" "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/numpy" "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/torch" @@ -15,11 +17,15 @@ import ( ) var allClasses = map[string]any{ + "builtins.getattr": builtins.Getattr{}, + "builtins.int": builtins.Int{}, + "collections.defaultdict": collections.DefaultDictClass{}, "flair.data.Dictionary": DictionaryClass{}, "flair.embeddings.FlairEmbeddings": FlairEmbeddingsClass{}, "flair.embeddings.StackedEmbeddings": StackedEmbeddingsClass{}, - "flair.embeddings.token.StackedEmbeddings": StackedEmbeddingsClass{}, "flair.embeddings.WordEmbeddings": WordEmbeddingsClass{}, + "flair.embeddings.token.FlairEmbeddings": FlairEmbeddingsClass{}, + "flair.embeddings.token.StackedEmbeddings": StackedEmbeddingsClass{}, "flair.embeddings.token.WordEmbeddings": WordEmbeddingsClass{}, "flair.models.language_model.LanguageModel": LanguageModelClass{}, "gensim.models.keyedvectors.Vocab": gensim.VocabClass{}, @@ -28,11 +34,11 @@ var allClasses = map[string]any{ "numpy.dtype": numpy.DTypeClass{}, "numpy.ndarray": numpy.NDArrayClass{}, "torch._utils._rebuild_parameter": torch.RebuildParameter{}, + "torch.backends.cudnn.rnn.Unserializable": torch.RNNUnserializableClass{}, "torch.nn.modules.dropout.Dropout": torch.DropoutClass{}, "torch.nn.modules.linear.Linear": torch.LinearClass{}, "torch.nn.modules.rnn.LSTM": torch.LSTMClass{}, "torch.nn.modules.sparse.Embedding": torch.SparseEmbeddingClass{}, - "torch.backends.cudnn.rnn.Unserializable": torch.RNNUnserializableClass{}, } func newUnpickler(r io.Reader) pickle.Unpickler { diff --git a/pkg/converter/flair/conversion/flair/wordembeddings.go b/pkg/converter/flair/conversion/flair/wordembeddings.go index d2544dc..1ba0767 100644 --- a/pkg/converter/flair/conversion/flair/wordembeddings.go +++ b/pkg/converter/flair/conversion/flair/wordembeddings.go @@ -10,6 +10,7 @@ import ( "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion" "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/gensim" "github.com/nlpodyssey/cybertron/pkg/converter/flair/conversion/torch" + "github.com/nlpodyssey/gopickle/types" "github.com/nlpodyssey/spago/mat" ) @@ -17,12 +18,13 @@ type WordEmbeddingsClass struct{} type WordEmbeddings struct { TokenEmbeddingsModule - Embeddings string - Name string - StaticEmbeddings bool - Embedding *torch.Embedding - Vocab map[string]int - embeddingLength int + Embeddings string + Name string + StaticEmbeddings bool + Embedding *torch.Embedding + Vocab map[string]int + InstanceParameters *types.Dict + embeddingLength int } var _ TokenEmbeddings = &WordEmbeddings{} @@ -48,6 +50,10 @@ func (w *WordEmbeddings) PyDictSet(k, v any) (err error) { switch k { case "embeddings": err = conversion.AssignAssertedType(v, &w.Embeddings) + case "get_cached_vec": + // present on older models, can be ignored + case "instance_parameters": + err = conversion.AssignAssertedType(v, &w.InstanceParameters) case "name": err = conversion.AssignAssertedType(v, &w.Name) case "static_embeddings": @@ -93,6 +99,17 @@ func (w *WordEmbeddings) setPrecomputedWordEmbeddings(kv *gensim.KeyedVectors) e return nil } +func (w *WordEmbeddings) PyGetAttribute(name string) (value any, exists bool, err error) { + switch name { + case "get_cached_vec": + // this ignores the get_cached_vec method when loading older versions + // it is needed for compatibility reasons + return nil, true, nil + default: + return nil, false, fmt.Errorf("WordEmbeddings: unexpected __getattribute__(%q)", name) + } +} + func (w *WordEmbeddings) LoadStateDictEntry(string, any) error { return fmt.Errorf("WordEmbeddings: loading from state dict entry not implemented") } diff --git a/pkg/converter/flair/conversion/torch/module.go b/pkg/converter/flair/conversion/torch/module.go index c411d1b..d9e987f 100644 --- a/pkg/converter/flair/conversion/torch/module.go +++ b/pkg/converter/flair/conversion/torch/module.go @@ -13,15 +13,16 @@ import ( ) type Module struct { - Training bool - Parameters *types.OrderedDict - Buffers *types.OrderedDict - BackwardHooks *types.OrderedDict - ForwardHooks *types.OrderedDict - ForwardPreHooks *types.OrderedDict - StateDictHooks *types.OrderedDict - LoadStateDictPreHooks *types.OrderedDict - Modules *types.OrderedDict + Training bool + Parameters *types.OrderedDict + Buffers *types.OrderedDict + BackwardHooks *types.OrderedDict + ForwardHooks *types.OrderedDict + ForwardPreHooks *types.OrderedDict + StateDictHooks *types.OrderedDict + LoadStateDictPreHooks *types.OrderedDict + Modules *types.OrderedDict + NonPersistentBuffersSet *types.Set } func GetSubModule[T any](mod Module, name string) (v T, err error) { @@ -70,6 +71,8 @@ func (m *Module) PyDictSet(k, v any) (err error) { err = conversion.AssignAssertedType(v, &m.ForwardPreHooks) case "_state_dict_hooks": err = conversion.AssignAssertedType(v, &m.StateDictHooks) + case "_non_persistent_buffers_set": + err = conversion.AssignAssertedType(v, &m.NonPersistentBuffersSet) case "_load_state_dict_pre_hooks": err = conversion.AssignAssertedType(v, &m.LoadStateDictPreHooks) case "_modules": diff --git a/pkg/converter/flair/conversion/torch/rnn.go b/pkg/converter/flair/conversion/torch/rnn.go index 6323b03..55ce002 100644 --- a/pkg/converter/flair/conversion/torch/rnn.go +++ b/pkg/converter/flair/conversion/torch/rnn.go @@ -42,6 +42,7 @@ type RNNBase struct { Bidirectional bool ProjSize int FlatWeightsNames []string + FlatWeights []*Parameter AllWeights [][]string Parameters map[string]*Parameter } @@ -202,6 +203,18 @@ func (r *RNNBase) PyDictSet(k, v any) (err error) { err = conversion.AssignAssertedType(v, &r.Bidirectional) case "_all_weights": err = r.convertAndSetAllWeights(v) + case "_flat_weights": + var l *types.List + err = conversion.AssignAssertedType(v, &l) + if err == nil { + err = conversion.AssignListToSlice(l, &r.FlatWeights) + } + case "_flat_weights_names": + var l *types.List + err = conversion.AssignAssertedType(v, &l) + if err == nil { + err = conversion.AssignListToSlice(l, &r.FlatWeightsNames) + } case "_data_ptrs", "_param_buf_size": default: err = fmt.Errorf("unexpected key with value %#v", v) diff --git a/pkg/converter/flair/conversion/utils.go b/pkg/converter/flair/conversion/utils.go index 7d0f448..559347d 100644 --- a/pkg/converter/flair/conversion/utils.go +++ b/pkg/converter/flair/conversion/utils.go @@ -12,6 +12,10 @@ import ( "github.com/nlpodyssey/spago/mat" ) +type PyAttributeGettable interface { + PyGetAttribute(name string) (value any, exists bool, err error) +} + func AssertType[T any](v any) (vt T, err error) { var ok bool vt, ok = v.(T) diff --git a/pkg/converter/flair/convert.go b/pkg/converter/flair/convert.go index c6980aa..7c40243 100644 --- a/pkg/converter/flair/convert.go +++ b/pkg/converter/flair/convert.go @@ -271,7 +271,7 @@ func (conv *converter[T]) encoderEmbeddingsTokensEncoderWordEmbeddings(st store. func (conv *converter[T]) encoderEmbeddingsTokensEncoderCharLMVocabulary(forward, backward *convflair.FlairEmbeddings) (*vocabulary.Vocabulary, error) { d1 := forward.LM.Dictionary d2 := backward.LM.Dictionary - if !reflect.DeepEqual(d1, d2) { + if !reflect.DeepEqual(d1.Idx2Item, d2.Idx2Item) { return nil, fmt.Errorf("FlairEmbeddings LM forward/backward dictionaries differ") } return vocabulary.New(d1.Idx2Item), nil