-
Notifications
You must be signed in to change notification settings - Fork 4
/
run_index.py
78 lines (63 loc) · 2.35 KB
/
run_index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
Indexing script.
Example usage:
python run_index.py --config-file ./examples/index/indexer_config.yaml
"""
# Remove warnings
import torchvision
torchvision.disable_beta_transforms_warning()
import warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*TypedStorage is deprecated.*")
warnings.filterwarnings("ignore", category=UserWarning, message="BertForMaskedLM has generative capabilities.*")
import os
import argparse
from src.mmore.utils import load_config
from src.mmore.type import MultimodalSample
from src.mmore.index.indexer import IndexerConfig, Indexer
from dataclasses import dataclass, field
import json
from dotenv import load_dotenv
load_dotenv()
# Set up logging
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Global logging configuration
#logging.basicConfig(format='%(asctime)s: %(message)s')
#logging.basicConfig(format='%(message)s')
logging.basicConfig(format='[INDEX 🗂️ ] %(message)s', level=logging.INFO)
# Suppress overly verbose logs from third-party libraries
logging.getLogger("transformers").setLevel(logging.CRITICAL)
@dataclass
class IndexerRunConfig:
    """Schema of the run configuration loaded from the --config-file YAML."""
    # Output directory of 'run_process.py'; must contain
    # 'merged/merged_results.jsonl' (see load_results below).
    documents_path: str
    # Nested indexer configuration, forwarded to Indexer.from_documents.
    indexer: IndexerConfig
    # Name of the target collection, forwarded to Indexer.from_documents.
    collection_name: str = 'my_docs'
    # Number of documents handled per batch by the indexer.
    batch_size: int = 64
def load_results(path: str, file_type: str = None):
    """Load the processed samples saved by 'run_process.py'.

    Reads `<path>/merged/merged_results.jsonl` (one JSON object per line)
    and deserializes each line into a ``MultimodalSample``.

    Args:
        path: Output directory of 'run_process.py'.
        file_type: Unused; kept for backward compatibility with callers.

    Returns:
        list: The deserialized ``MultimodalSample`` objects.

    Raises:
        FileNotFoundError: If the merged results file does not exist.
    """
    # Build the path portably instead of concatenating with '/'.
    results_file = os.path.join(path, 'merged', 'merged_results.jsonl')
    # Lazy %-formatting so the message is only rendered if the level is enabled.
    logger.info("Loading results from %s", results_file)
    results = []
    with open(results_file, "rb") as f:
        for line in f:
            # Skip blank/whitespace-only lines (e.g. a trailing newline),
            # which json.loads would otherwise reject.
            if line.strip():
                results.append(MultimodalSample.from_dict(json.loads(line)))
    logger.debug("Loaded %d results", len(results))
    return results
def get_args():
    """Parse and return this script's command-line arguments."""
    arg_parser = argparse.ArgumentParser(description='Index files for specified documents')
    arg_parser.add_argument(
        '--config-file',
        type=str,
        required=True,
        help='Path to a config file.',
    )
    return arg_parser.parse_args()
if __name__ == "__main__":
# get script args
args = get_args()
# Load the config file
config = load_config(args.config_file, IndexerRunConfig)
logger.info("Creating the indexer...")
indexer = Indexer.from_documents(
config=config.indexer,
documents=load_results(config.documents_path),
collection_name=config.collection_name,
batch_size=config.batch_size
)