Skip to content

Commit

Permalink
Switch from WebSockets to SSE for LMSYSBot
Browse files Browse the repository at this point in the history
  • Loading branch information
fathmike9123 authored and sunner committed Mar 7, 2024
1 parent 0af017c commit c6d653c
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 60 deletions.
128 changes: 72 additions & 56 deletions src/bots/huggingface/GradioBot.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import axios from "axios";
import WebSocketAsPromised from "websocket-as-promised";
import Bot from "@/bots/Bot";
import i18n from "@/i18n";
import { SSE } from "sse.js";

export default class GradioBot extends Bot {
static _brandId = "gradio"; // Brand id of the bot, should be unique. Used in i18n.
Expand All @@ -11,6 +11,7 @@ export default class GradioBot extends Bot {
static _fnIndexes = [0]; // Indexes of the APIs to call in order. Sniffer it by devtools.

config = {};
eventListeners = new Map();

constructor() {
super();
Expand Down Expand Up @@ -68,94 +69,103 @@ export default class GradioBot extends Bot {
async _sendFnIndex(fn_index, prompt, onUpdateResponse, callbackParam) {
const config = this.config;
const session_hash = await this.getChatContext();

const joinUrl = new URL(config.root + config.path + "/queue/join");
const data = this.makeData(fn_index, prompt);

const streamData = {
data,
event_data: null,
fn_index,
session_hash,
trigger_id: this._triggerId ?? 0,
};
const streamContext = await axios.post(joinUrl.toString(), streamData);

if (streamContext.status !== 200 || !streamContext.data.event_id) {
return Promise.reject(
i18n.global.t("error.failedConnectUrl", { url: joinUrl }),
);
}

return new Promise((resolve, reject) => {
try {
const url = new URL(config.root + config.path + "/queue/join");
url.protocol = url.protocol === "https:" ? "wss:" : "ws:";

const data = this.makeData(fn_index, prompt);

const wsp = new WebSocketAsPromised(url.toString(), {
packMessage: (data) => {
return JSON.stringify(data);
},
unpackMessage: (data) => {
return JSON.parse(data);
},
});

wsp.onUnpackedMessage.addListener(async (event) => {
if (event.msg === "send_hash") {
wsp.sendPacked({ fn_index, session_hash });
} else if (event.msg === "send_data") {
// Requested to send data
wsp.sendPacked({
data,
event_data: null,
fn_index,
session_hash,
});
} else if (event.msg === "estimation") {
if (event.rank > 0) {
const dataUrl = new URL(config.root + config.path + "/queue/data");
dataUrl.searchParams.set("session_hash", session_hash);

const source = new SSE(dataUrl.toString());

const onMessageEventHandler = (event) => {
const data = JSON.parse(event.data);

if (data.msg === "estimation") {
if (data.rank > 0) {
// Waiting in queue
event.rank_eta = Math.floor(event.rank_eta);
data.rank_eta = Math.floor(data.rank_eta);
onUpdateResponse(callbackParam, {
content: i18n.global.t("gradio.waiting", { ...event }),
content: i18n.global.t("gradio.waiting", { ...data }),
done: false,
});
}
} else if (event.msg === "process_generating") {
} else if (data.msg === "process_generating") {
// Generating data
if (event.success && event.output.data) {
if (data.success && data.output.data) {
onUpdateResponse(callbackParam, {
content: this.parseData(fn_index, event.output.data),
content: this.parseData(fn_index, data.output.data),
done: false,
});
} else {
reject(new Error(event.output.error));
reject(new Error(data.output.error));
}
} else if (event.msg === "process_completed") {
} else if (data.msg === "process_completed") {
// Done
if (event.success && event.output.data) {
if (data.success && data.output.data) {
if (
typeof event.output.data[2] !== "string" ||
event.output.data[2] === ""
typeof data.output.data[2] !== "string" ||
data.output.data[2] === ""
) {
onUpdateResponse(callbackParam, {
content: this.parseData(fn_index, event.output.data),
content: this.parseData(fn_index, data.output.data),
done: fn_index == this.constructor._fnIndexes.slice(-1), // Only the last one is done
});
} else {
const errorMsg = this.parseError(event.output.data[2]);
const errorMsg = this.parseError(data.output.data[2]);
reject(new Error(errorMsg));
}
} else {
reject(new Error(event.output.error));
}
wsp.removeAllListeners();
wsp.close();

this.removeAllEventListeners(source);
source.close();
resolve();
} else if (event.msg === "queue_full") {
} else if (data.msg === "queue_full") {
reject(i18n.global.t("gradio.queueFull"));
}
});
};

wsp.onClose.addListener((event) => {
console.log("WebSocket closed:", event);
wsp.removeAllListeners();
wsp.close();
const onAbortEventHandler = (event) => {
console.log("Server-Sent Event closed:", event);
this.removeAllEventListeners(source);
source.close();
reject(new Error(i18n.global.t("error.closedByServer")));
});
};

wsp.onError.addListener((event) => {
wsp.removeAllListeners();
wsp.close();
const onErrorEventHandler = (event) => {
this.removeAllEventListeners(source);
source.close();
reject(
i18n.global.t("error.failedConnectUrl", { url: event.target.url }),
);
});
};

this.eventListeners.set("message", onMessageEventHandler);
this.eventListeners.set("error", onErrorEventHandler);
this.eventListeners.set("abort", onAbortEventHandler);

for (const [eventName, eventHandler] of this.eventListeners) {
source.addEventListener(eventName, eventHandler);
}

wsp.open();
source.stream();
} catch (error) {
reject(error);
}
Expand All @@ -175,4 +185,10 @@ export default class GradioBot extends Bot {
parseError(errorMsg) {
return errorMsg;
}

removeAllEventListeners(source) {
for (const [eventName, eventHandler] of this.eventListeners) {
source.removeEventListener(eventName, eventHandler);
}
}
}
12 changes: 11 additions & 1 deletion src/bots/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,17 @@ const all = [
YouChatBot.getInstance(),
];

const disabled = ["ClaudeBot", "AlpacaBot", "HuggingChatBot", "Falcon180bBot"];
const disabled = [
"ClaudeBot",
"AlpacaBot",
"HuggingChatBot",
"Falcon180bBot",
"ChatGLM6bBot",
"ChatGLM36bBot",
"CodeLlamaBot",
"Vicuna7bBot",
"Wizardlm13bBot",
];

if (process.env.NODE_ENV !== "production") {
all.push(DevBot.getInstance());
Expand Down
13 changes: 10 additions & 3 deletions src/bots/lmsys/LMSYSBot.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ export default class LMSYSBot extends GradioBot {
static _outputFormat = "html"; // "markdown" or "html"
static _lock = new AsyncLock(); // Send requests in queue to save LMSYS

static _fnIndexes = [39, 40]; // Indexes of the APIs to call in order. Sniffer it by devtools.
static _fnIndexes = [41, 42]; // Indexes of the APIs to call in order. Sniffer it by devtools.
_triggerId = 93; // From devtools

constructor() {
super();
Expand All @@ -28,7 +29,7 @@ export default class LMSYSBot extends GradioBot {
makeData(fn_index, prompt) {
let r = null;
if (fn_index === this.constructor._fnIndexes[0]) {
r = [null, this.constructor._model, prompt];
r = [null, this.constructor._model, prompt, null];
} else if (fn_index === this.constructor._fnIndexes[1]) {
r = [null, 0.7, 1, 512];
} else if (fn_index === 43) {
Expand All @@ -40,7 +41,13 @@ export default class LMSYSBot extends GradioBot {
parseData(fn_index, data) {
let r = undefined;
if (fn_index === this.constructor._fnIndexes[1]) {
r = data[1][data[1].length - 1][1];
const dataOne = data[1];

if (dataOne.length > 0) {
const dataTwo = dataOne[dataOne.length - 1];
const dataThree = dataTwo[1];
r = dataThree;
}
}
if (!r) r = ""; // Sometimes the result from data[] is null
return r;
Expand Down

0 comments on commit c6d653c

Please sign in to comment.