add OpenAI Classifier #86

Merged · 2 commits · Nov 22, 2024
64 changes: 64 additions & 0 deletions docs/src/content/docs/classifiers/built-in/openai-classifier.mdx
@@ -0,0 +1,64 @@
---
title: OpenAI Classifier
description: How to configure the OpenAI classifier
---

The OpenAI Classifier is a built-in classifier for the Multi-Agent Orchestrator that leverages OpenAI's language models for intent classification. It provides robust classification capabilities using OpenAI's state-of-the-art models like GPT-4o.

The OpenAI Classifier extends the abstract `Classifier` class and uses the OpenAI API client to process requests and classify user intents.

## Features

- Utilizes OpenAI's advanced models (e.g., GPT-4o) for intent classification
- Configurable model selection and inference parameters
- Supports custom system prompts and variables (see the Custom System Prompt sketch below)
- Handles conversation history for context-aware classification

### Basic Usage

To use the OpenAIClassifier, you need to create an instance with your OpenAI API key and pass it to the Multi-Agent Orchestrator:

import { Tabs, TabItem } from '@astrojs/starlight/components';

<Tabs syncKey="runtime">
<TabItem label="TypeScript" icon="seti:typescript" color="blue">
```typescript
import { MultiAgentOrchestrator, OpenAIClassifier } from "multi-agent-orchestrator";

const openaiClassifier = new OpenAIClassifier({
  apiKey: 'your-openai-api-key'
});

const orchestrator = new MultiAgentOrchestrator({ classifier: openaiClassifier });
```
</TabItem>
</Tabs>
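
In practice, avoid hardcoding the key. A minimal sketch that reads it from an environment variable instead (assuming a Node.js runtime):

```typescript
// Read the key from the environment rather than committing it to source control.
const apiKey = process.env.OPENAI_API_KEY;
if (!apiKey) {
  throw new Error("OPENAI_API_KEY is not set");
}

const openaiClassifier = new OpenAIClassifier({ apiKey });
```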

### Custom Configuration

You can customize the OpenAIClassifier by providing additional options:

<Tabs syncKey="runtime">
<TabItem label="TypeScript" icon="seti:typescript" color="blue">
```typescript
const customOpenAIClassifier = new OpenAIClassifier({
  apiKey: 'your-openai-api-key',
  modelId: 'gpt-4o',
  inferenceConfig: {
    maxTokens: 500,
    temperature: 0.7,
    topP: 0.9,
    stopSequences: ['STOP']
  }
});

const orchestrator = new MultiAgentOrchestrator({ classifier: customOpenAIClassifier });
```
</TabItem>
</Tabs>

The OpenAIClassifier accepts the following configuration options:

- `apiKey` (required): Your OpenAI API key.
- `modelId` (optional): The ID of the OpenAI model to use. Defaults to GPT-4o mini (`OPENAI_MODEL_ID_GPT_O_MINI`).
- `inferenceConfig` (optional): Inference settings: `maxTokens` (defaults to 1000), `temperature`, `topP`, and `stopSequences`.
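
### Custom System Prompt

The feature list above mentions custom system prompts. A minimal sketch follows; it assumes the base `Classifier` exposes a `setSystemPrompt(template, variables)` helper as the other built-in classifiers do, and the `{{CUSTOM_GUIDANCE}}` placeholder name is hypothetical:

```typescript
const classifier = new OpenAIClassifier({ apiKey: 'your-openai-api-key' });

// Assumption: setSystemPrompt substitutes {{PLACEHOLDER}} entries from the
// variables map; CUSTOM_GUIDANCE is an illustrative variable name.
classifier.setSystemPrompt(
  `Classify the user's request and select the most suitable agent.
Additional guidance: {{CUSTOM_GUIDANCE}}`,
  { CUSTOM_GUIDANCE: 'Prefer the billing agent for payment questions.' }
);
```

### Using the Classifier

Classification runs automatically when the orchestrator routes a request. A minimal sketch, where the user and session IDs are illustrative placeholders:

```typescript
const response = await orchestrator.routeRequest(
  "My package never arrived",
  "user-123",    // userId (illustrative)
  "session-456"  // sessionId (illustrative)
);
```
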
138 changes: 138 additions & 0 deletions typescript/src/classifiers/openAIClassifier.ts
@@ -0,0 +1,138 @@
import OpenAI from "openai";
import {
  ConversationMessage,
  OPENAI_MODEL_ID_GPT_O_MINI
} from "../types";
import { isClassifierToolInput } from "../utils/helpers";
import { Logger } from "../utils/logger";
import { Classifier, ClassifierResult } from "./classifier";

export interface OpenAIClassifierOptions {
  // Optional: The ID of the OpenAI model to use for classification
  // If not provided, a default model may be used
  modelId?: string;

  // Optional: Configuration for the inference process
  inferenceConfig?: {
    // Maximum number of tokens to generate in the response
    maxTokens?: number;

    // Controls randomness in output generation
    temperature?: number;

    // Controls diversity of output via nucleus sampling
    topP?: number;

    // Array of sequences that will stop the model from generating further tokens
    stopSequences?: string[];
  };

  // The API key for authenticating with OpenAI's services
  apiKey: string;
}

export class OpenAIClassifier extends Classifier {
  private client: OpenAI;
  protected inferenceConfig: {
    maxTokens?: number;
    temperature?: number;
    topP?: number;
    stopSequences?: string[];
  };

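  // Tool schema the model is forced to call, yielding structured
  // classification output (selected agent + confidence).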
  private tools: OpenAI.ChatCompletionTool[] = [
    {
      type: "function",
      function: {
        name: 'analyzePrompt',
        description: 'Analyze the user input and provide structured output',
        parameters: {
          type: 'object',
          properties: {
            userinput: {
              type: 'string',
              description: 'The original user input',
            },
            selected_agent: {
              type: 'string',
              description: 'The name of the selected agent',
            },
            confidence: {
              type: 'number',
              description: 'Confidence level between 0 and 1',
            },
          },
          required: ['userinput', 'selected_agent', 'confidence'],
        },
      },
    },
  ];

  constructor(options: OpenAIClassifierOptions) {
    super();

    if (!options.apiKey) {
      throw new Error("OpenAI API key is required");
    }
    this.client = new OpenAI({ apiKey: options.apiKey });
    this.modelId = options.modelId || OPENAI_MODEL_ID_GPT_O_MINI;

    const defaultMaxTokens = 1000;
    this.inferenceConfig = {
      maxTokens: options.inferenceConfig?.maxTokens ?? defaultMaxTokens,
      temperature: options.inferenceConfig?.temperature,
      topP: options.inferenceConfig?.topP,
      stopSequences: options.inferenceConfig?.stopSequences,
    };
  }

  async processRequest(
    inputText: string,
    chatHistory: ConversationMessage[]
  ): Promise<ClassifierResult> {
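    // Note: chatHistory is accepted for interface compatibility but is not
    // currently forwarded to the model; only the system prompt and the
    // latest user input are sent.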
    const messages: OpenAI.ChatCompletionMessageParam[] = [
      {
        role: 'system',
        content: this.systemPrompt
      },
      {
        role: 'user',
        content: inputText
      }
    ];

    try {
      const response = await this.client.chat.completions.create({
        model: this.modelId,
        messages: messages,
        max_tokens: this.inferenceConfig.maxTokens,
        temperature: this.inferenceConfig.temperature,
        top_p: this.inferenceConfig.topP,
        stop: this.inferenceConfig.stopSequences,
        tools: this.tools,
        tool_choice: { type: "function", function: { name: "analyzePrompt" } }
      });

      const toolCall = response.choices[0]?.message?.tool_calls?.[0];

      if (!toolCall || toolCall.function.name !== "analyzePrompt") {
        throw new Error("No valid tool call found in the response");
      }

      const toolInput = JSON.parse(toolCall.function.arguments);

      if (!isClassifierToolInput(toolInput)) {
        throw new Error("Tool input does not match expected structure");
      }

      const intentClassifierResult: ClassifierResult = {
        selectedAgent: this.getAgentById(toolInput.selected_agent),
        confidence: parseFloat(toolInput.confidence),
      };
      return intentClassifierResult;

    } catch (error) {
      Logger.logger.error("Error processing request:", error);
      throw error;
    }
  }
}
1 change: 1 addition & 0 deletions typescript/src/index.ts
@@ -12,6 +12,7 @@ export { AgentResponse } from './agents/agent';

export { BedrockClassifier, BedrockClassifierOptions } from './classifiers/bedrockClassifier';
export { AnthropicClassifier, AnthropicClassifierOptions } from './classifiers/anthropicClassifier';
export { OpenAIClassifier, OpenAIClassifierOptions } from './classifiers/openAIClassifier';

export { Retriever } from './retrievers/retriever';
export { AmazonKnowledgeBasesRetriever, AmazonKnowledgeBasesRetrieverOptions } from './retrievers/AmazonKBRetriever';
178 changes: 178 additions & 0 deletions typescript/tests/classifiers/OpenAIClassifier.test.ts
@@ -0,0 +1,178 @@
import { OpenAIClassifier, OpenAIClassifierOptions } from '../../src/classifiers/openAIClassifier';
import OpenAI from 'openai';
import { ConversationMessage, OPENAI_MODEL_ID_GPT_O_MINI } from "../../src/types";
import { MockAgent } from '../mock/mockAgent';

// Mock the OpenAI module
jest.mock('openai');

describe('OpenAIClassifier', () => {
  let classifier: OpenAIClassifier;
  let mockCreateCompletion: jest.Mock;

  const defaultOptions: OpenAIClassifierOptions = {
    apiKey: 'test-api-key',
  };

  beforeEach(() => {
    // Create a mock for the create method
    mockCreateCompletion = jest.fn();

    // Mock the OpenAI constructor and chat.completions.create method
    (OpenAI as jest.MockedClass<typeof OpenAI>).mockImplementation(() => ({
      chat: {
        completions: {
          create: mockCreateCompletion,
        },
      },
    } as unknown as OpenAI));

    classifier = new OpenAIClassifier(defaultOptions);
  });

  afterEach(() => {
    jest.clearAllMocks();
  });

  describe('constructor', () => {
    it('should create an instance with default options', () => {
      expect(classifier).toBeInstanceOf(OpenAIClassifier);
      expect(OpenAI).toHaveBeenCalledWith({ apiKey: 'test-api-key' });
    });

    it('should use custom model ID if provided', () => {
      const customOptions: OpenAIClassifierOptions = {
        ...defaultOptions,
        modelId: 'custom-model-id',
      };
      const customClassifier = new OpenAIClassifier(customOptions);
      expect(customClassifier['modelId']).toBe('custom-model-id');
    });

    it('should use default model ID if not provided', () => {
      expect(classifier['modelId']).toBe(OPENAI_MODEL_ID_GPT_O_MINI);
    });

    it('should set inference config with custom values', () => {
      const customOptions: OpenAIClassifierOptions = {
        ...defaultOptions,
        inferenceConfig: {
          maxTokens: 500,
          temperature: 0.7,
          topP: 0.9,
          stopSequences: ['STOP'],
        },
      };
      const customClassifier = new OpenAIClassifier(customOptions);
      expect(customClassifier['inferenceConfig']).toEqual(customOptions.inferenceConfig);
    });

    it('should throw an error if API key is not provided', () => {
      expect(() => new OpenAIClassifier({ apiKey: '' })).toThrow('OpenAI API key is required');
    });
  });

  describe('processRequest', () => {
    const inputText = 'Hello, how are you?';
    const chatHistory: ConversationMessage[] = [];

    it('should process request successfully', async () => {
      const mockResponse = {
        choices: [{
          message: {
            tool_calls: [{
              function: {
                name: 'analyzePrompt',
                arguments: JSON.stringify({
                  userinput: inputText,
                  selected_agent: 'test-agent',
                  confidence: 0.95,
                }),
              },
            }],
          },
        }],
      };

      const mockAgent = {
        'test-agent': new MockAgent({
          name: "test-agent",
          description: 'A tech support agent',
        })
      };

      classifier.setAgents(mockAgent);

      mockCreateCompletion.mockResolvedValue(mockResponse);

      const result = await classifier.processRequest(inputText, chatHistory);

      expect(mockCreateCompletion).toHaveBeenCalledWith({
        model: OPENAI_MODEL_ID_GPT_O_MINI,
        max_tokens: 1000,
        messages: [
          {
            role: 'system',
            content: classifier['systemPrompt'], // Use the actual system prompt
          },
          {
            role: 'user',
            content: inputText
          }
        ],
        temperature: undefined,
        top_p: undefined,
        stop: undefined,
        tools: classifier['tools'], // Use the actual tools array
        tool_choice: { type: "function", function: { name: "analyzePrompt" } }
      });

      expect(result).toEqual({
        selectedAgent: expect.any(MockAgent),
        confidence: 0.95,
      });
    });

    it('should throw an error if no tool calls are found in the response', async () => {
      const mockResponse = {
        choices: [{
          message: {}
        }],
      };

      mockCreateCompletion.mockResolvedValue(mockResponse);

      await expect(classifier.processRequest(inputText, chatHistory))
        .rejects.toThrow('No valid tool call found in the response');
    });

    it('should throw an error if tool input does not match expected structure', async () => {
      const mockResponse = {
        choices: [{
          message: {
            tool_calls: [{
              function: {
                name: 'analyzePrompt',
                arguments: JSON.stringify({
                  invalidKey: 'invalidValue',
                }),
              },
            }],
          },
        }],
      };

      mockCreateCompletion.mockResolvedValue(mockResponse);

      await expect(classifier.processRequest(inputText, chatHistory))
        .rejects.toThrow('Tool input does not match expected structure');
    });

    it('should throw an error if API request fails', async () => {
      const errorMessage = 'API request failed';
      mockCreateCompletion.mockRejectedValue(new Error(errorMessage));

      await expect(classifier.processRequest(inputText, chatHistory))
        .rejects.toThrow(errorMessage);
    });
  });
});