From 03a2db718d3866dea7c89cac7961f426a103566e Mon Sep 17 00:00:00 2001 From: Rajaniraiyn R Date: Tue, 19 Nov 2024 18:40:53 +0530 Subject: [PATCH 1/2] add openai classifier --- .../src/classifiers/openAIClassifier.ts | 138 ++++++++++++++ typescript/src/index.ts | 1 + .../classifiers/OpenAIClassifier.test.ts | 178 ++++++++++++++++++ 3 files changed, 317 insertions(+) create mode 100644 typescript/src/classifiers/openAIClassifier.ts create mode 100644 typescript/tests/classifiers/OpenAIClassifier.test.ts diff --git a/typescript/src/classifiers/openAIClassifier.ts b/typescript/src/classifiers/openAIClassifier.ts new file mode 100644 index 00000000..7990b321 --- /dev/null +++ b/typescript/src/classifiers/openAIClassifier.ts @@ -0,0 +1,138 @@ +import OpenAI from "openai"; +import { + ConversationMessage, + OPENAI_MODEL_ID_GPT_O_MINI +} from "../types"; +import { isClassifierToolInput } from "../utils/helpers"; +import { Logger } from "../utils/logger"; +import { Classifier, ClassifierResult } from "./classifier"; + +export interface OpenAIClassifierOptions { + // Optional: The ID of the OpenAI model to use for classification + // If not provided, a default model may be used + modelId?: string; + + // Optional: Configuration for the inference process + inferenceConfig?: { + // Maximum number of tokens to generate in the response + maxTokens?: number; + + // Controls randomness in output generation + temperature?: number; + + // Controls diversity of output via nucleus sampling + topP?: number; + + // Array of sequences that will stop the model from generating further tokens + stopSequences?: string[]; + }; + + // The API key for authenticating with OpenAI's services + apiKey: string; +} + +export class OpenAIClassifier extends Classifier { + private client: OpenAI; + protected inferenceConfig: { + maxTokens?: number; + temperature?: number; + topP?: number; + stopSequences?: string[]; + }; + + private tools: OpenAI.ChatCompletionTool[] = [ + { + type: "function", + function: { + name: 'analyzePrompt', + description: 'Analyze the user input and provide structured output', + parameters: { + type: 'object', + properties: { + userinput: { + type: 'string', + description: 'The original user input', + }, + selected_agent: { + type: 'string', + description: 'The name of the selected agent', + }, + confidence: { + type: 'number', + description: 'Confidence level between 0 and 1', + }, + }, + required: ['userinput', 'selected_agent', 'confidence'], + }, + }, + }, + ]; + + constructor(options: OpenAIClassifierOptions) { + super(); + + if (!options.apiKey) { + throw new Error("OpenAI API key is required"); + } + this.client = new OpenAI({ apiKey: options.apiKey }); + this.modelId = options.modelId || OPENAI_MODEL_ID_GPT_O_MINI; + + const defaultMaxTokens = 1000; + this.inferenceConfig = { + maxTokens: options.inferenceConfig?.maxTokens ?? defaultMaxTokens, + temperature: options.inferenceConfig?.temperature, + topP: options.inferenceConfig?.topP, + stopSequences: options.inferenceConfig?.stopSequences, + }; + } + + async processRequest( + inputText: string, + chatHistory: ConversationMessage[] + ): Promise { + const messages: OpenAI.ChatCompletionMessageParam[] = [ + { + role: 'system', + content: this.systemPrompt + }, + { + role: 'user', + content: inputText + } + ]; + + try { + const response = await this.client.chat.completions.create({ + model: this.modelId, + messages: messages, + max_tokens: this.inferenceConfig.maxTokens, + temperature: this.inferenceConfig.temperature, + top_p: this.inferenceConfig.topP, + tools: this.tools, + tool_choice: { type: "function", function: { name: "analyzePrompt" } } + }); + + const toolCall = response.choices[0]?.message?.tool_calls?.[0]; + + if (!toolCall || toolCall.function.name !== "analyzePrompt") { + throw new Error("No valid tool call found in the response"); + } + + const toolInput = JSON.parse(toolCall.function.arguments); + + if (!isClassifierToolInput(toolInput)) { + throw new Error("Tool input does not match expected structure"); + } + + const intentClassifierResult: ClassifierResult = { + selectedAgent: this.getAgentById(toolInput.selected_agent), + confidence: parseFloat(toolInput.confidence), + }; + return intentClassifierResult; + + } catch (error) { + Logger.logger.error("Error processing request:", error); + throw error; + } + } +} \ No newline at end of file diff --git a/typescript/src/index.ts b/typescript/src/index.ts index 894bd667..81265ddd 100644 --- a/typescript/src/index.ts +++ b/typescript/src/index.ts @@ -12,6 +12,7 @@ export { AgentResponse } from './agents/agent'; export { BedrockClassifier, BedrockClassifierOptions } from './classifiers/bedrockClassifier'; export { AnthropicClassifier, AnthropicClassifierOptions } from './classifiers/anthropicClassifier'; +export { OpenAIClassifier, OpenAIClassifierOptions } from "./classifiers/openAIClassifier" export { Retriever } from './retrievers/retriever'; export { AmazonKnowledgeBasesRetriever, AmazonKnowledgeBasesRetrieverOptions } from './retrievers/AmazonKBRetriever'; diff --git a/typescript/tests/classifiers/OpenAIClassifier.test.ts b/typescript/tests/classifiers/OpenAIClassifier.test.ts new file mode 100644 index 00000000..46a759f6 --- /dev/null +++ b/typescript/tests/classifiers/OpenAIClassifier.test.ts @@ -0,0 +1,178 @@ +import { OpenAIClassifier, OpenAIClassifierOptions } from '../../src/classifiers/openAIClassifier'; +import OpenAI from 'openai'; +import { ConversationMessage, OPENAI_MODEL_ID_GPT_O_MINI, ParticipantRole } from "../../src/types"; +import { MockAgent } from '../mock/mockAgent'; + +// Mock the OpenAI module +jest.mock('openai'); + +describe('OpenAIClassifier', () => { + let classifier: OpenAIClassifier; + let mockCreateCompletion: jest.Mock; + + const defaultOptions: OpenAIClassifierOptions = { + apiKey: 'test-api-key', + }; + + beforeEach(() => { + // Create a mock for the create method + mockCreateCompletion = jest.fn(); + + // Mock the OpenAI constructor and chat.completions.create method + (OpenAI as jest.MockedClass).mockImplementation(() => ({ + chat: { + completions: { + create: mockCreateCompletion, + }, + }, + } as unknown as OpenAI)); + + classifier = new OpenAIClassifier(defaultOptions); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + describe('constructor', () => { + it('should create an instance with default options', () => { + expect(classifier).toBeInstanceOf(OpenAIClassifier); + expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-api-key' }); + }); + + it('should use custom model ID if provided', () => { + const customOptions: OpenAIClassifierOptions = { + ...defaultOptions, + modelId: 'custom-model-id', + }; + const customClassifier = new OpenAIClassifier(customOptions); + expect(customClassifier['modelId']).toBe('custom-model-id'); + }); + + it('should use default model ID if not provided', () => { + expect(classifier['modelId']).toBe(OPENAI_MODEL_ID_GPT_O_MINI); + }); + + it('should set inference config with custom values', () => { + const customOptions: OpenAIClassifierOptions = { + ...defaultOptions, + inferenceConfig: { + maxTokens: 500, + temperature: 0.7, + topP: 0.9, + stopSequences: ['STOP'], + }, + }; + const customClassifier = new OpenAIClassifier(customOptions); + expect(customClassifier['inferenceConfig']).toEqual(customOptions.inferenceConfig); + }); + + it('should throw an error if API key is not provided', () => { + expect(() => new OpenAIClassifier({ apiKey: '' })).toThrow('OpenAI API key is required'); + }); + }); + + describe('processRequest', () => { + const inputText = 'Hello, how are you?'; + const chatHistory: ConversationMessage[] = []; + + it('should process request successfully', async () => { + const mockResponse = { + choices: [{ + message: { + tool_calls: [{ + function: { + name: 'analyzePrompt', + arguments: JSON.stringify({ + userinput: inputText, + selected_agent: 'test-agent', + confidence: 0.95, + }), + }, + }], + }, + }], + }; + + const mockAgent = { + 'test-agent': new MockAgent({ + name: "test-agent", + description: 'A tech support agent', + }) + }; + + classifier.setAgents(mockAgent); + + mockCreateCompletion.mockResolvedValue(mockResponse); + + const result = await classifier.processRequest(inputText, chatHistory); + + expect(mockCreateCompletion).toHaveBeenCalledWith({ + model: OPENAI_MODEL_ID_GPT_O_MINI, + max_tokens: 1000, + messages: [ + { + role: 'system', + content: classifier['systemPrompt'], // Use the actual system prompt + }, + { + role: 'user', + content: inputText + } + ], + temperature: undefined, + top_p: undefined, + tools: classifier['tools'], // Use the actual tools array + tool_choice: { type: "function", function: { name: "analyzePrompt" } } + }); + + expect(result).toEqual({ + selectedAgent: expect.any(MockAgent), + confidence: 0.95, + }); + }); + + it('should throw an error if no tool calls are found in the response', async () => { + const mockResponse = { + choices: [{ + message: {} + }], + }; + + mockCreateCompletion.mockResolvedValue(mockResponse); + + await expect(classifier.processRequest(inputText, chatHistory)) + .rejects.toThrow('No valid tool call found in the response'); + }); + + it('should throw an error if tool input does not match expected structure', async () => { + const mockResponse = { + choices: [{ + message: { + tool_calls: [{ + function: { + name: 'analyzePrompt', + arguments: JSON.stringify({ + invalidKey: 'invalidValue', + }), + }, + }], + }, + }], + }; + + mockCreateCompletion.mockResolvedValue(mockResponse); + + await expect(classifier.processRequest(inputText, chatHistory)) + .rejects.toThrow('Tool input does not match expected structure'); + }); + + it('should throw an error if API request fails', async () => { + const errorMessage = 'API request failed'; + mockCreateCompletion.mockRejectedValue(new Error(errorMessage)); + + await expect(classifier.processRequest(inputText, chatHistory)) + .rejects.toThrow(errorMessage); + }); + }); +}); \ No newline at end of file From ce29f84e01f0783a4ee8c7fb84145b2b444defa6 Mon Sep 17 00:00:00 2001 From: Rajaniraiyn R Date: Tue, 19 Nov 2024 18:42:44 +0530 Subject: [PATCH 2/2] update docs --- .../built-in/openai-classifier.mdx | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 docs/src/content/docs/classifiers/built-in/openai-classifier.mdx diff --git a/docs/src/content/docs/classifiers/built-in/openai-classifier.mdx b/docs/src/content/docs/classifiers/built-in/openai-classifier.mdx new file mode 100644 index 00000000..e46a7d21 --- /dev/null +++ b/docs/src/content/docs/classifiers/built-in/openai-classifier.mdx @@ -0,0 +1,64 @@ +--- +title: OpenAI Classifier +description: How to configure the OpenAI classifier +--- + +The OpenAI Classifier is a built-in classifier for the Multi-Agent Orchestrator that leverages OpenAI's language models for intent classification. It provides robust classification capabilities using OpenAI's state-of-the-art models like GPT-4o. + +The OpenAI Classifier extends the abstract `Classifier` class and uses the OpenAI API client to process requests and classify user intents. + +## Features + +- Utilizes OpenAI's advanced models (e.g., GPT-4) for intent classification +- Configurable model selection and inference parameters +- Supports custom system prompts and variables +- Handles conversation history for context-aware classification + +### Basic Usage + +To use the OpenAIClassifier, you need to create an instance with your OpenAI API key and pass it to the Multi-Agent Orchestrator: + +import { Tabs, TabItem } from '@astrojs/starlight/components'; + + + + ```typescript + import { OpenAIClassifier } from "multi-agent-orchestrator"; + import { MultiAgentOrchestrator } from "multi-agent-orchestrator"; + + const openaiClassifier = new OpenAIClassifier({ + apiKey: 'your-openai-api-key' + }); + + const orchestrator = new MultiAgentOrchestrator({ classifier: openaiClassifier }); + ``` + + + +### Custom Configuration + +You can customize the OpenAIClassifier by providing additional options: + + + + ```typescript + const customOpenAIClassifier = new OpenAIClassifier({ + apiKey: 'your-openai-api-key', + modelId: 'gpt-4o', + inferenceConfig: { + maxTokens: 500, + temperature: 0.7, + topP: 0.9, + stopSequences: [''] + } + }); + + const orchestrator = new MultiAgentOrchestrator({ classifier: customOpenAIClassifier }); + ``` + + + +The OpenAIClassifier accepts the following configuration options: + +- `api_key` (required): Your OpenAI API key. +- `model_id` (optional): The ID of the OpenAI model to use. Defaults to GPT-4 Turbo \ No newline at end of file