-
Notifications
You must be signed in to change notification settings - Fork 85
/
BertM.h
48 lines (38 loc) · 1.43 KB
/
BertM.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#ifndef CUBERT_BERTMGPU_H
#define CUBERT_BERTMGPU_H
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <string>
#include <vector>
#include "cuBERT.h"
#include "cuBERT/Bert.h"
#include "cuBERT/tensorflow/Graph.h"
namespace cuBERT {
/**
 * Holds one Graph<T> together with a list of Bert<T> instances, a parallel
 * list of mutexes and an atomic counter `rr`.
 *
 * NOTE(review): presumably there is one Bert<T> instance per GPU, chosen
 * round-robin through `rr` and guarded by the mutex at the same index;
 * confirm against the implementation file.
 *
 * @tparam T element type of the raw output buffer (the comments below speak
 *           of "half" and "float" models, so presumably half or float).
 */
template <typename T>
class BertM {
public:
    /**
     * @param model_file          path of the model file to load.
     * @param max_batch_size      largest batch size later passed to compute().
     * @param seq_length          sequence length of the model inputs.
     * @param num_hidden_layers   number of hidden layers; defaults to 12.
     * @param num_attention_heads number of attention heads; defaults to 12.
     */
    explicit BertM(const char *model_file,
                   size_t max_batch_size,
                   size_t seq_length,
                   size_t num_hidden_layers = 12,
                   size_t num_attention_heads = 12);

    virtual ~BertM();

    // The class keeps raw pointers and declares a destructor, so copying is
    // forbidden explicitly. The std::atomic member already made the implicit
    // copy operations deleted; spelling it out keeps that guarantee even if
    // the member list changes. No move operations are declared either, so
    // the type is neither copyable nor movable.
    BertM(const BertM &) = delete;
    BertM &operator=(const BertM &) = delete;

    /**
     * Runs the model on one batch and writes the result into a raw buffer.
     *
     * @param batch_size  number of samples in this batch.
     * @param input_ids   token ids (presumably batch_size * seq_length
     *                    entries, like the two arrays below; TODO confirm).
     * @param input_mask  input mask values.
     * @param segment_ids segment ids.
     * @param output      destination buffer of element type T.
     * @param output_type which output to produce; defaults to cuBERT_LOGITS.
     * @return an unsigned value whose meaning is not visible in this header
     *         (presumably the index of the instance used; TODO confirm).
     */
    unsigned int compute(size_t batch_size,
                         int *input_ids, int8_t *input_mask, int8_t *segment_ids,
                         T *output,
                         cuBERT_OutputType output_type = cuBERT_LOGITS);

    /**
     * Runs the model on one batch and fills a cuBERT_Output structure.
     *
     * output_to_float = true:
     *   - for a half model, the output is always float: the method converts
     *     half to float;
     *   - for a float model, this flag is not used.
     *
     * @param batch_size      number of samples in this batch.
     * @param input_ids       token ids.
     * @param input_mask      input mask values.
     * @param segment_ids     segment ids.
     * @param output          structure receiving the results.
     * @param output_to_float see above; defaults to false.
     * @return same kind of value as the raw-buffer overload.
     */
    unsigned int compute(size_t batch_size,
                         int *input_ids, int8_t *input_mask, int8_t *segment_ids,
                         cuBERT_Output *output, bool output_to_float = false);

    // Sequence length; publicly readable.
    size_t seq_length;

private:
    Graph<T> graph;

    // Parallel lists: mutex_instances[i] presumably guards bert_instances[i]
    // (TODO confirm). Both hold raw pointers; ownership is handled outside
    // this header, presumably in the constructor/destructor.
    std::vector<Bert<T> *> bert_instances;
    std::vector<std::mutex *> mutex_instances;

    // Atomic 8-bit counter, so it wraps modulo 256. By its name it is
    // presumably the round-robin cursor over bert_instances.
    std::atomic<uint8_t> rr;
};
}
#endif //CUBERT_BERTMGPU_H