Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Gemini Model Support with Search Grounding #292

Merged
merged 12 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,8 @@
"turbo": "latest",
"typescript": "^5.2.2"
},
"dependencies": {
"@ai-sdk/google": "^1.0.12"
},
"packageManager": "pnpm@8.15.4"
}
2 changes: 1 addition & 1 deletion packages/plugin/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export class FileOrganizerSettings {
useVaultTitles = true;
showLocalLLMInChat = false;
customFolderInstructions = "";
selectedModel: "gpt-4o" | "llama3.2" = "gpt-4o";
selectedModel: "gpt-4o" | "llama3.2" | "gemini-2.0-flash-exp" = "gpt-4o";
customModelName = "llama3.2";
tagScoreThreshold = 70;
formatBehavior: "override" | "newFile" = "override";
Expand Down
12 changes: 6 additions & 6 deletions packages/plugin/views/ai-chat/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ export const ChatComponent: React.FC<ChatComponentProps> = ({
currentDatetime: window.moment().format("YYYY-MM-DDTHH:mm:ssZ"),
enableScreenpipe: plugin.settings.enableScreenpipe,
newUnifiedContext: contextString,
model: plugin.settings.selectedModel, // Pass selected model to server
};

const {
Expand All @@ -98,12 +99,11 @@ export const ChatComponent: React.FC<ChatComponentProps> = ({
fetch: async (url, options) => {
logMessage(plugin.settings.showLocalLLMInChat, "showLocalLLMInChat");
logMessage(selectedModel, "selectedModel");
// local llm disabled or using gpt-4o
if (!plugin.settings.showLocalLLMInChat) {
// return normal server fetch
return fetch(url, options);
}
if (selectedModel === "gpt-4o") {
// Handle different model types
if (!plugin.settings.showLocalLLMInChat ||
selectedModel === "gpt-4o" ||
selectedModel === "gemini-2.0-flash-exp") {
// Use server fetch for non-local models
return fetch(url, options);
}

Expand Down
8 changes: 7 additions & 1 deletion packages/plugin/views/ai-chat/model-selector.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export const ModelSelector: React.FC<ModelSelectorProps> = ({
return;
}
onModelSelect(model);
if (model === "gpt-4o" || model === "llama3.2") {
if (model === "gpt-4o" || model === "llama3.2" || model === "gemini-2.0-flash-exp") {
plugin.settings.selectedModel = model;
}
await plugin.saveSettings();
Expand Down Expand Up @@ -72,6 +72,12 @@ export const ModelSelector: React.FC<ModelSelectorProps> = ({
>
gpt-4o
</div>
<div
onClick={() => handleModelSelect("gemini-2.0-flash-exp")}
className="cursor-pointer block w-full text-left px-4 py-2 text-sm text-[--text-normal] hover:bg-[--background-modifier-hover]"
>
gemini-2.0-flash-exp
</div>
{isCustomizing ? (
<div className="px-4 py-2">
<input
Expand Down
2 changes: 1 addition & 1 deletion packages/plugin/views/ai-chat/types.ts
Original file line number Diff line number Diff line change
@@ -1 +1 @@
/**
 * Chat model identifier.
 *
 * The known models are listed as literal members so call sites can narrow
 * and editors can autocomplete on them. `(string & {})` (instead of plain
 * `string`) keeps arbitrary custom model names assignable WITHOUT collapsing
 * the whole union to `string` — a bare `| string` would swallow the literals
 * and defeat autocomplete/narrowing entirely.
 */
export type ModelType =
  | "gpt-4o"
  | "llama3.2"
  | "gemini-2.0-flash-exp"
  | (string & {});
13 changes: 9 additions & 4 deletions packages/web/app/api/(newai)/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,13 @@ import { getModel } from "@/lib/models";
import { getChatSystemPrompt } from "@/lib/prompts/chat-prompt";

export const maxDuration = 60;
const MODEL_NAME = process.env.MODEL_NAME;

const settingsSchema = z.object({
renameInstructions: z.string().optional(),
customFolderInstructions: z.string().optional(),
imageInstructions: z.string().optional(),
});

export async function POST(req: NextRequest) {
console.log("Chat using model:", MODEL_NAME);
const model = getModel(MODEL_NAME);
try {
const { userId } = await handleAuthorization(req);
const {
Expand All @@ -27,8 +23,17 @@ export async function POST(req: NextRequest) {
enableScreenpipe,
currentDatetime,
unifiedContext: oldUnifiedContext,
model: bodyModel,
} = await req.json();

const chosenModelName =
bodyModel ??
process.env.CHAT_MODEL ??
process.env.MODEL_NAME ??
"gemini-2.0-flash-exp";
console.log("Chat using model:", chosenModelName);
const model = getModel(chosenModelName);

// if oldunified context do what is below if not just return newunified context
const contextString =
newUnifiedContext ||
Expand Down
12 changes: 9 additions & 3 deletions packages/web/lib/models.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import { anthropic } from "@ai-sdk/anthropic";
import { openai } from "@ai-sdk/openai";
import { google } from "@ai-sdk/google";

// Fallback model name used by getModel() when an unknown name is requested.
const DEFAULT_MODEL = "gpt-4o";

// Registry of supported chat models, keyed by the model name the client
// sends in the request body.
// NOTE(review): the diff residue previously listed "gpt-4o-2024-08-06"
// twice (old + new line of the same hunk), which is a duplicate-key error
// in TypeScript — only one entry is kept here.
const models = {
  "gpt-4o": openai("gpt-4o"),
  "gpt-4o-2024-08-06": openai("gpt-4o-2024-08-06"),
  "gpt-4o-mini": openai("gpt-4o-mini"),
  "claude-3-5-sonnet-20240620": anthropic("claude-3-5-sonnet-20240620"),
  "claude-3-5-sonnet-20241022": anthropic("claude-3-5-sonnet-20241022"),
  "claude-3-5-haiku-20241022": anthropic("claude-3-5-haiku-20241022"),
  // Search grounding lets Gemini augment answers with Google Search results.
  "gemini-2.0-flash-exp": google("gemini-2.0-flash-exp", {
    useSearchGrounding: true,
  }),
};

export const getModel = (name: string) => {
if (!models[name]) {
console.log(`Model ${name} not found`);
console.log(`Defaulting to gpt-4o`);
return models["gpt-4o"];
console.log(`Defaulting to ${DEFAULT_MODEL}`);
return models[DEFAULT_MODEL];
}
console.log(`Using model ${name}`);

Expand Down
Loading