-
Notifications
You must be signed in to change notification settings - Fork 1
/
__init__.py
158 lines (123 loc) · 5.23 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from enum import Enum, unique
from .base import VectorWorkload, VectorWorkloadSequence, SingleVectorWorkloadSequence
from vsb import logger
@unique
class Workload(Enum):
"""Set of supported workloads, the value is the string used to
specify via --workload=
"""
Mnist = "mnist"
MnistTest = "mnist-test"
Nq768 = "nq768"
Nq768Test = "nq768-test"
YFCC = "yfcc-10M"
YFCCTest = "yfcc-test"
Cohere768 = "cohere768"
Cohere768Test = "cohere768-test"
MsMarcoV2Ada = "msmarco-v2-ada"
MsMarcoV2AdaTest = "msmarco-v2-ada-test"
Synthetic = "synthetic"
SyntheticProportional = "synthetic-proportional"
def build(self, **kwargs) -> VectorWorkload:
"""Construct an instance of VectorWorkload based on the value of the enum."""
cls = self._get_class()
return cls(name=self.value, **kwargs)
def _get_class(self) -> type[VectorWorkload]:
"""Return the VectorWorkload class to use, based on the value of the enum"""
match self:
case Workload.Synthetic:
from .synthetic_workload.synthetic_workload import SyntheticWorkload
return SyntheticWorkload
case Workload.SyntheticProportional:
from .synthetic_workload.synthetic_workload import (
SyntheticProportionalWorkload,
)
return SyntheticProportionalWorkload
case Workload.Mnist:
from .mnist.mnist import Mnist
return Mnist
case Workload.MnistTest:
from .mnist.mnist import MnistTest
return MnistTest
case Workload.Nq768:
from .nq_768_tasb.nq_768_tasb import Nq768Tasb
return Nq768Tasb
case Workload.Nq768Test:
from .nq_768_tasb.nq_768_tasb import Nq768TasbTest
return Nq768TasbTest
case Workload.YFCC:
from .yfcc.yfcc import YFCC
return YFCC
case Workload.YFCCTest:
from .yfcc.yfcc import YFCCTest
return YFCCTest
case Workload.Cohere768:
from .cohere_768.cohere_768 import Cohere768
return Cohere768
case Workload.Cohere768Test:
from .cohere_768.cohere_768 import Cohere768Test
return Cohere768Test
case Workload.MsMarcoV2Ada:
from .msmarco_v2_ada.msmarco_v2_ada import MsMarcoV2Ada
return MsMarcoV2Ada
case Workload.MsMarcoV2AdaTest:
from .msmarco_v2_ada.msmarco_v2_ada import MsMarcoV2AdaTest
return MsMarcoV2AdaTest
def describe(self) -> tuple[str, int, int, str, int]:
"""Return a tuple with attributes of the workload: name, dataset size, dimensionality, distance metric, and query count."""
cls = self._get_class()
return (
self.value,
cls.record_count(),
cls.dimensions(),
cls.metric().value,
cls.request_count(),
)
@unique
class WorkloadSequence(Enum):
"""Set of supported workload sequences, the value is the string used to
specify via --workload=.
"""
MnistSplit = "mnist-split"
MnistDoubleTest = "mnist-double-test"
Nq768Split = "nq768-split"
Cohere768Split = "cohere768-split"
YFCCSplit = "yfcc-split"
SyntheticRunbook = "synthetic-runbook"
def build(self, **kwargs) -> VectorWorkloadSequence:
"""Construct an instance of VectorWorkload based on the value of the enum."""
cls = self._get_class()
return cls(self.value, **kwargs)
def _get_class(self) -> type[VectorWorkloadSequence]:
"""Return the VectorWorkloadSequence class to use, based on the value of the enum"""
match self:
case WorkloadSequence.SyntheticRunbook:
from .synthetic_workload.synthetic_workload import SyntheticRunbook
return SyntheticRunbook
case WorkloadSequence.MnistSplit:
from .mnist.mnist import MnistSplit
return MnistSplit
case WorkloadSequence.MnistDoubleTest:
from .mnist.mnist import MnistDoubleTest
return MnistDoubleTest
case WorkloadSequence.Nq768Split:
from .nq_768_tasb.nq_768_tasb import Nq768TasbSplit
return Nq768TasbSplit
case WorkloadSequence.Cohere768Split:
from .cohere_768.cohere_768 import Cohere768Split
return Cohere768Split
case WorkloadSequence.YFCCSplit:
from .yfcc.yfcc import YFCCSplit
return YFCCSplit
pass
def build_workload_sequence(name: str, **kwargs) -> VectorWorkloadSequence:
    """Takes either a Workload or WorkloadSequence name and returns the corresponding
    WorkloadSequence. Workloads will be wrapped into single-element WorkloadSequences.

    Args:
        name: The value of a WorkloadSequence or Workload enum member.
        **kwargs: Forwarded unchanged to the selected enum member's ``build``.

    Returns:
        The built sequence, or a SingleVectorWorkloadSequence wrapping the
        built workload when ``name`` identifies a plain Workload.

    Raises:
        ValueError: if ``name`` matches neither a WorkloadSequence nor a
            Workload.
    """
    logger.debug(f"Building workload sequence {name}")

    # Only the name lookup is guarded. A ValueError raised while *building* a
    # valid sequence must propagate as-is instead of being mistaken for an
    # unknown name and replaced by a misleading "not a valid Workload" error.
    try:
        sequence = WorkloadSequence(name)
    except ValueError:
        sequence = None
    if sequence is not None:
        return sequence.build(**kwargs)

    # Not a sequence name: try to build a Workload.
    try:
        workload_kind = Workload(name)
    except ValueError:
        raise ValueError(
            f"{name!r} is not a valid WorkloadSequence or Workload"
        ) from None
    workload = workload_kind.build(**kwargs)
    return SingleVectorWorkloadSequence(name, workload)