forked from datamics/flair
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
66 lines (54 loc) · 1.63 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from pathlib import Path
from typing import List

import flair.datasets
from flair.data import Corpus
from flair.embeddings import (
    CharacterEmbeddings,
    FlairEmbeddings,
    StackedEmbeddings,
    TokenEmbeddings,
    WordEmbeddings,
)
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from flair.training_utils import EvaluationMetric
from flair.visual.training_curves import Plotter
# The tagger is trained to predict the "upos" tag type (universal
# part-of-speech tags, presumably).
tag_type = "upos"

# Load the UD_ENGLISH corpus via flair.datasets and print it as a quick
# sanity check before any model is built.
corpus: Corpus = flair.datasets.UD_ENGLISH()
print(corpus)

# Collect the label vocabulary for `tag_type` from the corpus; printing
# idx2item lets the collected tag set be inspected before training.
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
print(tag_dictionary.idx2item)
# Token-level embeddings to stack. Only GloVe word embeddings are active;
# uncomment entries below to add more:
#   - CharacterEmbeddings(): character embeddings
#   - FlairEmbeddings('news-forward') / FlairEmbeddings('news-backward'):
#     contextual string embeddings
embedding_types: List[TokenEmbeddings] = [
    WordEmbeddings("glove"),
    # CharacterEmbeddings(),
    # FlairEmbeddings('news-forward'),
    # FlairEmbeddings('news-backward'),
]

# Wrap the list in a single StackedEmbeddings object; this is what the
# tagger receives as its `embeddings` argument.
embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)
# Build the sequence tagger: it consumes the stacked `embeddings`, is trained
# on labels of `tag_type`, and uses `tag_dictionary` (built from the corpus
# above) as its label vocabulary.
tagger: SequenceTagger = SequenceTagger(
    # Hidden-layer size of the tagger (presumably its recurrent/BiLSTM state;
    # confirm against the installed flair version's SequenceTagger docs).
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_type=tag_type,
    # Turns on the tagger's CRF layer.
    use_crf=True,
)
# Directory that receives the training output; the plotting step below reads
# loss.tsv and weights.txt back from it. Named for the UPOS task this script
# trains, and defined once so training and plotting always agree on it.
base_path = Path("resources/taggers/example-upos")

# Train the tagger on the corpus.
trainer: ModelTrainer = ModelTrainer(tagger, corpus)
trainer.train(
    base_path,
    learning_rate=0.1,
    mini_batch_size=32,
    max_epochs=20,
    # NOTE(review): presumably disables shuffling of the training data between
    # epochs; confirm this is intended (e.g. reproducibility) and not a leftover.
    shuffle=False,
)

# Plot the loss curves and weight traces from the files the trainer wrote.
plotter = Plotter()
plotter.plot_training_curves(base_path / "loss.tsv")
plotter.plot_weights(base_path / "weights.txt")