Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow using Vector Stores directly as Tools #12311

Merged
merged 19 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cypress/composables/ndv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Getters
*/

import { getVisibleSelect } from '../utils';
import { getVisibleSelect } from '../utils/popper';

export function getCredentialSelect(eq = 0) {
return cy.getByTestId('node-credentials-select').eq(eq);
Expand Down
21 changes: 20 additions & 1 deletion cypress/composables/workflow.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { getManualChatModal } from './modals/chat-modal';
import { clickGetBackToCanvas, getParameterInputByName } from './ndv';
import { ROUTES } from '../constants';

/**
Expand Down Expand Up @@ -127,7 +128,7 @@ export function navigateToNewWorkflowPage(preventNodeViewUnload = true) {
});
}

export function addSupplementalNodeToParent(
function connectNodeToParent(
nodeName: string,
endpointType: EndpointType,
parentNodeName: string,
Expand All @@ -141,6 +142,15 @@ export function addSupplementalNodeToParent(
} else {
getNodeCreatorItems().contains(nodeName).click();
}
}

/**
 * Connects `nodeName` to `parentNodeName` on the given endpoint and then asserts
 * that the connection looked up with `(parentNodeName, nodeName)` exists.
 *
 * @param nodeName Name of the node to add.
 * @param endpointType Endpoint type the node is attached through.
 * @param parentNodeName Name of the node the new node is attached to.
 * @param exactMatch Forwarded unchanged to `connectNodeToParent`; defaults to `false`.
 */
export function addSupplementalNodeToParent(
	nodeName: string,
	endpointType: EndpointType,
	parentNodeName: string,
	exactMatch = false,
) {
	connectNodeToParent(nodeName, endpointType, parentNodeName, exactMatch);

	// Verify the connection only after the node has been attached.
	const connection = getConnectionBySourceAndTarget(parentNodeName, nodeName);
	connection.should('exist');
}

Expand All @@ -160,6 +170,15 @@ export function addToolNodeToParent(nodeName: string, parentNodeName: string) {
addSupplementalNodeToParent(nodeName, 'ai_tool', parentNodeName);
}

/**
 * Attaches a vector store node to `parentNodeName` through the `ai_tool` endpoint.
 *
 * After connecting, it checks that the `mode` parameter input shows the
 * "Retrieve Documents (As Tool for AI Agent)" value, goes back to the canvas and
 * asserts that the connection looked up with `(nodeName, parentNodeName)` exists.
 *
 * @param nodeName Name of the vector store node to add.
 * @param parentNodeName Name of the node the vector store is attached to.
 */
export function addVectorStoreToolToParent(nodeName: string, parentNodeName: string) {
	const toolEndpoint = 'ai_tool';
	const expectedModeLabel = 'Retrieve Documents (As Tool for AI Agent)';

	connectNodeToParent(nodeName, toolEndpoint, parentNodeName, false);

	// The mode is expected to be already selected at this point; nothing is typed or clicked here.
	const modeInput = getParameterInputByName('mode').find('input');
	modeInput.should('have.value', expectedModeLabel);

	clickGetBackToCanvas();

	const connection = getConnectionBySourceAndTarget(nodeName, parentNodeName);
	connection.should('exist');
}

/**
 * Adds `nodeName` to `parentNodeName` through the `ai_outputParser` endpoint by
 * delegating to `addSupplementalNodeToParent` (with its default `exactMatch`).
 */
export function addOutputParserNodeToParent(nodeName: string, parentNodeName: string) {
addSupplementalNodeToParent(nodeName, 'ai_outputParser', parentNodeName);
}
Expand Down
14 changes: 12 additions & 2 deletions cypress/e2e/4-node-creator.cy.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { clickGetBackToCanvas } from '../composables/ndv';
import {
addNodeToCanvas,
addRetrieverNodeToParent,
addVectorStoreNodeToParent,
addVectorStoreToolToParent,
getNodeCreatorItems,
} from '../composables/workflow';
import { IF_NODE_NAME } from '../constants';
import { AGENT_NODE_NAME, IF_NODE_NAME, MANUAL_CHAT_TRIGGER_NODE_NAME } from '../constants';
import { NodeCreator } from '../pages/features/node-creator';
import { NDV } from '../pages/ndv';
import { WorkflowPage as WorkflowPageClass } from '../pages/workflow';
Expand Down Expand Up @@ -536,12 +538,20 @@ describe('Node Creator', () => {
});
});

it('should add node directly for sub-connection', () => {
// Chains sub-connections: an In-Memory Vector Store is attached to a Vector Store Retriever,
// which is itself attached to a Question and Answer Chain.
it('should add node directly for sub-connection as vector store', () => {
addNodeToCanvas('Question and Answer Chain', true);
addRetrieverNodeToParent('Vector Store Retriever', 'Question and Answer Chain');
// NOTE(review): presumably dismisses the panel opened after adding the node — confirm
cy.realPress('Escape');
addVectorStoreNodeToParent('In-Memory Vector Store', 'Vector Store Retriever');
cy.realPress('Escape');
// NOTE(review): only three nodes are added explicitly above; the fourth is presumably
// created automatically (e.g. a trigger) — confirm
WorkflowPage.getters.canvasNodes().should('have.length', 4);
});

// Attaches an In-Memory Vector Store directly to the agent as a tool. The assertions
// (selected mode value and existing connection) live inside `addVectorStoreToolToParent`.
it('should add node directly for sub-connection as tool', () => {
addNodeToCanvas(MANUAL_CHAT_TRIGGER_NODE_NAME, true);
addNodeToCanvas(AGENT_NODE_NAME, true, true);
// Go back to the canvas before attaching the sub-node to the agent.
clickGetBackToCanvas();

addVectorStoreToolToParent('In-Memory Vector Store', AGENT_NODE_NAME);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ import { getConnectionHintNoticeField } from '@utils/sharedFields';

export class ToolVectorStore implements INodeType {
description: INodeTypeDescription = {
displayName: 'Vector Store Tool',
displayName: 'Vector Store Question Answer Tool',
name: 'toolVectorStore',
icon: 'fa:database',
iconColor: 'black',
group: ['transform'],
version: [1],
description: 'Retrieve context from vector store',
description: 'Answer questions with a vector store',
defaults: {
name: 'Vector Store Tool',
name: 'Answer questions with a vector store',
},
codex: {
categories: ['AI'],
Expand Down Expand Up @@ -60,20 +60,23 @@ export class ToolVectorStore implements INodeType {
properties: [
getConnectionHintNoticeField([NodeConnectionType.AiAgent]),
{
displayName: 'Name',
displayName: 'Data Name',
name: 'name',
type: 'string',
default: '',
placeholder: 'e.g. company_knowledge_base',
placeholder: 'e.g. users_info',
validateType: 'string-alphanumeric',
description: 'Name of the vector store',
description:
'Name of the data in vector store. This will be used to fill this tool description: Useful for when you need to answer questions about [name]. Whenever you need information about [data description], you should ALWAYS use this. Input should be a fully formed question.',
},
{
displayName: 'Description',
displayName: 'Description of Data',
name: 'description',
type: 'string',
default: '',
placeholder: 'Retrieves data about [insert information about your data here]...',
placeholder: "[Describe your data here, e.g. a user's name, email, etc.]",
description:
'Describe the data in vector store. This will be used to fill this tool description: Useful for when you need to answer questions about [name]. Whenever you need information about [data description], you should ALWAYS use this. Input should be a fully formed question.',
typeOptions: {
rows: 3,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ export class VectorStorePGVector extends createVectorStoreNode({
testedBy: 'postgresConnectionTest',
},
],
operationModes: ['load', 'insert', 'retrieve'],
operationModes: ['load', 'insert', 'retrieve', 'retrieve-as-tool'],
},
sharedFields,
insertFields,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ export class VectorStorePinecone extends createVectorStoreNode({
required: true,
},
],
operationModes: ['load', 'insert', 'retrieve', 'update'],
operationModes: ['load', 'insert', 'retrieve', 'update', 'retrieve-as-tool'],
},
methods: { listSearch: { pineconeIndexSearch } },
retrieveFields,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ export class VectorStoreSupabase extends createVectorStoreNode({
required: true,
},
],
operationModes: ['load', 'insert', 'retrieve', 'update'],
operationModes: ['load', 'insert', 'retrieve', 'update', 'retrieve-as-tool'],
},
methods: {
listSearch: { supabaseTableNameSearch },
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import type { DocumentInterface } from '@langchain/core/documents';
import type { Embeddings } from '@langchain/core/embeddings';
import type { VectorStore } from '@langchain/core/vectorstores';
import { mock } from 'jest-mock-extended';
import type { DynamicTool } from 'langchain/tools';
import type { ISupplyDataFunctions, NodeParameterValueType } from 'n8n-workflow';

import type { VectorStoreNodeConstructorArgs } from './createVectorStoreNode';
import { createVectorStoreNode } from './createVectorStoreNode';

// Replace `logWrapper` with a stub that returns its first argument under a `logWrapped` key,
// so tests can read back the exact object that was passed to it via `data.response.logWrapped`.
jest.mock('@utils/logWrapper', () => ({
logWrapper: jest.fn().mockImplementation((val: DynamicTool) => ({ logWrapped: val })),
}));

// Node parameters common to every test; each test spreads these and adds its own `mode` etc.
const DEFAULT_PARAMETERS = {
options: {},
topK: 1,
};

// Canned value resolved by the mocked `similaritySearchVectorWithScore`: [document, number] tuples.
// NOTE(review): the number is presumably the similarity score; the tests never assert on it — confirm
const MOCK_DOCUMENTS: Array<[DocumentInterface, number]> = [
[
{
pageContent: 'first page',
metadata: {
id: 123,
},
},
0,
],
[
{
pageContent: 'second page',
metadata: {
id: 567,
},
},
0,
],
];

// Query string passed to the tool's `func` in the retrieve-as-tool tests.
const MOCK_SEARCH_VALUE = 'search value';
// Canned value resolved by the mocked `embedQuery`; expected as the first argument of the vector search.
const MOCK_EMBEDDED_SEARCH_VALUE = [1, 2, 3];

/**
 * Unit tests for the node type returned by `createVectorStoreNode`, covering what
 * `supplyData` returns in the `retrieve` and `retrieve-as-tool` modes.
 *
 * `logWrapper` is mocked at module level to return `{ logWrapped: value }`, so each
 * test unwraps `data.response.logWrapped` to reach the supplied object.
 */
describe('createVectorStoreNode', () => {
	const vectorStore = mock<VectorStore>({
		similaritySearchVectorWithScore: jest.fn().mockResolvedValue(MOCK_DOCUMENTS),
	});

	const vectorStoreNodeArgs = mock<VectorStoreNodeConstructorArgs>({
		sharedFields: [],
		insertFields: [],
		loadFields: [],
		retrieveFields: [],
		updateFields: [],
		getVectorStoreClient: jest.fn().mockReturnValue(vectorStore),
	});

	const embeddings = mock<Embeddings>({
		embedQuery: jest.fn().mockResolvedValue(MOCK_EMBEDDED_SEARCH_VALUE),
	});

	const context = mock<ISupplyDataFunctions>({
		getNodeParameter: jest.fn(),
		getInputConnectionData: jest.fn().mockReturnValue(embeddings),
	});

	// The mocks above are shared by every test. Drop the recorded calls before each test so
	// `toHaveBeenCalled*` assertions cannot be satisfied by calls made in a previous test.
	// `clearAllMocks` only resets call history; the configured return values stay in place.
	beforeEach(() => {
		jest.clearAllMocks();
	});

	describe('retrieve mode', () => {
		it('supplies vector store as data', async () => {
			// ARRANGE
			const parameters: Record<string, NodeParameterValueType | object> = {
				...DEFAULT_PARAMETERS,
				mode: 'retrieve',
			};
			context.getNodeParameter.mockImplementation(
				(parameterName: string): NodeParameterValueType | object => parameters[parameterName],
			);

			// ACT
			const VectorStoreNodeType = createVectorStoreNode(vectorStoreNodeArgs);
			const nodeType = new VectorStoreNodeType();
			const data = await nodeType.supplyData.call(context, 1);
			const wrappedVectorStore = (data.response as { logWrapped: VectorStore }).logWrapped;

			// ASSERT
			expect(wrappedVectorStore).toEqual(vectorStore);
			expect(vectorStoreNodeArgs.getVectorStoreClient).toHaveBeenCalled();
		});
	});

	describe('retrieve-as-tool mode', () => {
		it('supplies DynamicTool that queries vector store and returns documents with metadata', async () => {
			// ARRANGE
			// The key must be `toolDescription`: that is the key the description assertion below
			// reads. With any other key both sides of the assertion are `undefined` and it
			// passes without checking anything.
			const parameters: Record<string, NodeParameterValueType | object> = {
				...DEFAULT_PARAMETERS,
				mode: 'retrieve-as-tool',
				toolDescription: 'tool description',
				toolName: 'tool name',
				includeDocumentMetadata: true,
			};
			context.getNodeParameter.mockImplementation(
				(parameterName: string): NodeParameterValueType | object => parameters[parameterName],
			);

			// ACT
			const VectorStoreNodeType = createVectorStoreNode(vectorStoreNodeArgs);
			const nodeType = new VectorStoreNodeType();
			const data = await nodeType.supplyData.call(context, 1);
			const tool = (data.response as { logWrapped: DynamicTool }).logWrapped;
			const output = await tool?.func(MOCK_SEARCH_VALUE);

			// ASSERT
			expect(tool?.getName()).toEqual(parameters.toolName);
			expect(tool?.description).toEqual(parameters.toolDescription);
			expect(embeddings.embedQuery).toHaveBeenCalledWith(MOCK_SEARCH_VALUE);
			// `parameters.filter` is intentionally not set, so the third argument is expected to be `undefined`.
			expect(vectorStore.similaritySearchVectorWithScore).toHaveBeenCalledWith(
				MOCK_EMBEDDED_SEARCH_VALUE,
				parameters.topK,
				parameters.filter,
			);
			expect(output).toEqual([
				{ type: 'text', text: JSON.stringify(MOCK_DOCUMENTS[0][0]) },
				{ type: 'text', text: JSON.stringify(MOCK_DOCUMENTS[1][0]) },
			]);
		});

		it('supplies DynamicTool that queries vector store and returns documents without metadata', async () => {
			// ARRANGE
			// Same fixture as above except `includeDocumentMetadata` is switched off.
			const parameters: Record<string, NodeParameterValueType | object> = {
				...DEFAULT_PARAMETERS,
				mode: 'retrieve-as-tool',
				toolDescription: 'tool description',
				toolName: 'tool name',
				includeDocumentMetadata: false,
			};
			context.getNodeParameter.mockImplementation(
				(parameterName: string): NodeParameterValueType | object => parameters[parameterName],
			);

			// ACT
			const VectorStoreNodeType = createVectorStoreNode(vectorStoreNodeArgs);
			const nodeType = new VectorStoreNodeType();
			const data = await nodeType.supplyData.call(context, 1);
			const tool = (data.response as { logWrapped: DynamicTool }).logWrapped;
			const output = await tool?.func(MOCK_SEARCH_VALUE);

			// ASSERT
			expect(tool?.getName()).toEqual(parameters.toolName);
			expect(tool?.description).toEqual(parameters.toolDescription);
			expect(embeddings.embedQuery).toHaveBeenCalledWith(MOCK_SEARCH_VALUE);
			expect(vectorStore.similaritySearchVectorWithScore).toHaveBeenCalledWith(
				MOCK_EMBEDDED_SEARCH_VALUE,
				parameters.topK,
				parameters.filter,
			);
			// Only `pageContent` is expected in the serialized documents; `metadata` is left out.
			expect(output).toEqual([
				{ type: 'text', text: JSON.stringify({ pageContent: MOCK_DOCUMENTS[0][0].pageContent }) },
				{ type: 'text', text: JSON.stringify({ pageContent: MOCK_DOCUMENTS[1][0].pageContent }) },
			]);
		});
	});
});
Loading
Loading