#!/usr/bin/env python3
# llm_router.py
import os
import torch
import requests
import json
import openai
from production_transformer import ProductionTransformer


class LLMRouter:
    """
    Routes inference to local or remote LLM backends:
      - local (ProductionTransformer placeholder)
      - OpenAI
      - Together.ai
      - Ollama
    """

    def __init__(self, local_vocab_size=10000, local_ckpt="transformer_checkpoint.pt"):
        self.local_ckpt = local_ckpt
        self.local_vocab_size = local_vocab_size
        self.local_model = None
        # Env variables for external APIs
        self.openai_api_key = os.getenv("OPENAI_API_KEY", "")
        self.together_api_key = os.getenv("TOGETHER_API_KEY", "")
        # Ollama's default HTTP port is 11434
        self.ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
        openai.api_key = self.openai_api_key
    def load_local_model(self):
        # Lazily build the local model, loading weights if a checkpoint exists.
        if self.local_model is None:
            model = ProductionTransformer(vocab_size=self.local_vocab_size)
            if os.path.exists(self.local_ckpt):
                model.load_state_dict(torch.load(self.local_ckpt, map_location="cpu"))
            model.eval()
            self.local_model = model
        return self.local_model
    def generate(self, prompt, backend="local"):
        if backend == "local":
            return self._generate_local(prompt)
        elif backend == "openai":
            return self._generate_openai(prompt)
        elif backend == "together":
            return self._generate_together(prompt)
        elif backend == "ollama":
            return self._generate_ollama(prompt)
        else:
            return f"[Error] Unsupported backend: {backend}"
    def _generate_local(self, prompt):
        # Placeholder: runs a dummy forward pass through the local model and
        # returns a canned string; `prompt` is not tokenized or used yet.
        model = self.load_local_model()
        dummy_input = torch.randint(0, self.local_vocab_size, (1, 20))
        with torch.no_grad():
            logits = model(dummy_input)
        return "[Local Model Placeholder Response]"
    def _generate_openai(self, prompt):
        if not self.openai_api_key:
            return "[OpenAI API key not provided]"
        try:
            # Uses the legacy Completion endpoint of the pre-1.0 openai SDK;
            # newer SDK versions expose a different client interface.
            response = openai.Completion.create(
                engine="text-davinci-003",
                prompt=prompt,
                max_tokens=100,
                temperature=0.7
            )
            return response["choices"][0]["text"].strip()
        except Exception as e:
            return f"[OpenAI Error: {str(e)}]"
    def _generate_together(self, prompt):
        if not self.together_api_key:
            return "[Together.ai API key not provided]"
        try:
            endpoint = "https://api.together.ai/generate"
            headers = {
                "Authorization": f"Bearer {self.together_api_key}",
                "Content-Type": "application/json"
            }
            payload = {
                "model": "together/galactica-6.7b",
                "prompt": prompt,
                "max_tokens": 100
            }
            resp = requests.post(endpoint, headers=headers, data=json.dumps(payload))
            if resp.status_code == 200:
                data = resp.json()
                return data.get("text", "[No response from Together.ai]")
            else:
                return f"[Together.ai Error: {resp.status_code} {resp.text}]"
        except Exception as e:
            return f"[Together.ai Error: {str(e)}]"
    def _generate_ollama(self, prompt):
        try:
            # Ollama serves generation at /api/generate; with "stream": False
            # it returns a single JSON object whose text is in "response".
            endpoint = f"{self.ollama_endpoint}/api/generate"
            payload = {"prompt": prompt, "model": "llama2", "stream": False}
            resp = requests.post(endpoint, json=payload)
            if resp.status_code == 200:
                data = resp.json()
                return data.get("response", "[No Ollama response]")
            else:
                return f"[Ollama Error: {resp.status_code} {resp.text}]"
        except Exception as e:
            return f"[Ollama Error: {str(e)}]"
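

# Minimal usage sketch. Assumes production_transformer.py provides
# ProductionTransformer and that API keys / local servers are configured;
# backend names other than "local" will call external services.
if __name__ == "__main__":
    router = LLMRouter()
    # The local placeholder path needs no credentials.
    print(router.generate("Hello, world!", backend="local"))
    # Remote backends, if configured:
    # print(router.generate("Hello, world!", backend="openai"))
    # print(router.generate("Hello, world!", backend="ollama"))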