From 848d81bd3c903662dbeeda170d06d1449d626be6 Mon Sep 17 00:00:00 2001 From: Rohan Shah Date: Wed, 23 Oct 2024 16:27:20 -0400 Subject: [PATCH] Add imports (#161) ## Problem Add four endpoints of the `BulkOperationsApi`. ## Solution Added the following four endpoints of the `BulkOperationsApi`: 1. `startImport(String uri, String integrationId, ImportErrorMode.OnErrorEnum errorMode)` 2. `describeImport(Integer limit, String paginationToken)` 3. `listImport(String id)` 4. `cancelImport(String id)` ## Type of Change - [ ] Bug fix (non-breaking change which fixes an issue) - [X] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update - [ ] Infrastructure change (CI configs, etc) - [ ] Non-code change (docs, etc) - [ ] None of the above: (explain here) ## Test Plan Added unit tests. --- README.md | 76 +++++++ .../integration/dataPlane/QueryErrorTest.java | 4 +- .../dataPlane/UpsertErrorTest.java | 4 +- .../java/io/pinecone/clients/AsyncIndex.java | 200 +++++++++++++++++- .../java/io/pinecone/clients/Pinecone.java | 2 +- .../java/io/pinecone/clients/ImportsTest.java | 161 ++++++++++++++ 6 files changed, 443 insertions(+), 4 deletions(-) create mode 100644 src/test/java/io/pinecone/clients/ImportsTest.java diff --git a/README.md b/README.md index e019d981..d93e5bca 100644 --- a/README.md +++ b/README.md @@ -643,6 +643,82 @@ RerankResult result = inference.rerank(model, query, documents, rankFields, topN System.out.println(result.getData()); ``` +# Imports +## Start an import + +The following example initiates an asynchronous import of vectors from object storage into the index. + +```java +import org.openapitools.db_data.client.ApiException; +import org.openapitools.db_data.client.model.ImportErrorMode; +import org.openapitools.db_data.client.model.StartImportResponse; +... + +// Initialize pinecone object +Pinecone pinecone = new Pinecone.Builder("PINECONE_API_KEY").build(); +// Get async imports connection object +AsyncIndex asyncIndex = pinecone.getAsyncIndexConnection("PINECONE_INDEX_NAME"); + +// s3 uri +String uri = "s3://path/to/file.parquet"; + +// Start an import +StartImportResponse response = asyncIndex.startImport(uri, "123-456-789", ImportErrorMode.OnErrorEnum.CONTINUE); +``` + +## List imports + +The following example lists all recent and ongoing import operations for the specified index. + +```java +import org.openapitools.db_data.client.ApiException; +import org.openapitools.db_data.client.model.ListImportsResponse; +... + +// Initialize pinecone object +Pinecone pinecone = new Pinecone.Builder("PINECONE_API_KEY").build(); +// Get async imports connection object +AsyncIndex asyncIndex = pinecone.getAsyncIndexConnection("PINECONE_INDEX_NAME"); + +// List imports +ListImportsResponse response = asyncIndex.listImports(100, "some-pagination-token"); +``` + +## Describe an import + +The following example retrieves detailed information about a specific import operation using its unique identifier. + +```java +import org.openapitools.db_data.client.ApiException; +import org.openapitools.db_data.client.model.ImportModel; +... + +// Initialize pinecone object +Pinecone pinecone = new Pinecone.Builder("PINECONE_API_KEY").build(); +// Get async imports connection object +AsyncIndex asyncIndex = pinecone.getAsyncIndexConnection("PINECONE_INDEX_NAME"); + +// Describe import +ImportModel importDetails = asyncIndex.describeImport("1"); +``` + +## Cancel an import + +The following example attempts to cancel an ongoing import operation using its unique identifier. + +```java +import org.openapitools.db_data.client.ApiException; +... + +// Initialize pinecone object +Pinecone pinecone = new Pinecone.Builder("PINECONE_API_KEY").build(); +// Get async imports connection object +AsyncIndex asyncIndex = pinecone.getAsyncIndexConnection("PINECONE_INDEX_NAME"); + +// Cancel import +asyncIndex.cancelImport("2"); +``` + ## Examples - The data and control plane operation examples can be found in `io/pinecone/integration` folder. \ No newline at end of file diff --git a/src/integration/java/io/pinecone/integration/dataPlane/QueryErrorTest.java b/src/integration/java/io/pinecone/integration/dataPlane/QueryErrorTest.java index 88d01cf7..59283576 100644 --- a/src/integration/java/io/pinecone/integration/dataPlane/QueryErrorTest.java +++ b/src/integration/java/io/pinecone/integration/dataPlane/QueryErrorTest.java @@ -2,6 +2,7 @@ import io.pinecone.clients.AsyncIndex; import io.pinecone.clients.Index; +import io.pinecone.configs.PineconeConfig; import io.pinecone.configs.PineconeConnection; import io.pinecone.exceptions.PineconeValidationException; import io.pinecone.proto.VectorServiceGrpc; @@ -24,6 +25,7 @@ public class QueryErrorTest { @BeforeAll public static void setUp() throws IOException, InterruptedException { + PineconeConfig config = mock(PineconeConfig.class); PineconeConnection connectionMock = mock(PineconeConnection.class); VectorServiceGrpc.VectorServiceBlockingStub stubMock = mock(VectorServiceGrpc.VectorServiceBlockingStub.class); @@ -33,7 +35,7 @@ public static void setUp() throws IOException, InterruptedException { when(connectionMock.getAsyncStub()).thenReturn(asyncStubMock); index = new Index(connectionMock, "some-index-name"); - asyncIndex = new AsyncIndex(connectionMock, "some-index-name"); + asyncIndex = new AsyncIndex(config, connectionMock, "some-index-name"); } @Test diff --git a/src/integration/java/io/pinecone/integration/dataPlane/UpsertErrorTest.java b/src/integration/java/io/pinecone/integration/dataPlane/UpsertErrorTest.java index 267c4a54..1cba4d0a 100644 --- a/src/integration/java/io/pinecone/integration/dataPlane/UpsertErrorTest.java +++ b/src/integration/java/io/pinecone/integration/dataPlane/UpsertErrorTest.java @@ -2,6 +2,7 @@ import io.pinecone.clients.AsyncIndex; import io.pinecone.clients.Index; +import io.pinecone.configs.PineconeConfig; import io.pinecone.configs.PineconeConnection; import io.pinecone.exceptions.PineconeException; import io.pinecone.exceptions.PineconeValidationException; @@ -27,6 +28,7 @@ public class UpsertErrorTest { @BeforeAll public static void setUp() throws IOException, InterruptedException { + PineconeConfig config = mock(PineconeConfig.class); PineconeConnection connectionMock = mock(PineconeConnection.class); VectorServiceGrpc.VectorServiceBlockingStub stubMock = mock(VectorServiceGrpc.VectorServiceBlockingStub.class); @@ -36,7 +38,7 @@ public static void setUp() throws IOException, InterruptedException { when(connectionMock.getAsyncStub()).thenReturn(asyncStubMock); index = new Index(connectionMock, "some-index-name"); - asyncIndex = new AsyncIndex(connectionMock, "some-index-name"); + asyncIndex = new AsyncIndex(config, connectionMock, "some-index-name"); } @Test diff --git a/src/main/java/io/pinecone/clients/AsyncIndex.java b/src/main/java/io/pinecone/clients/AsyncIndex.java index 8d822b6d..08514916 100644 --- a/src/main/java/io/pinecone/clients/AsyncIndex.java +++ b/src/main/java/io/pinecone/clients/AsyncIndex.java @@ -5,14 +5,32 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.Struct; import io.pinecone.commons.IndexInterface; +import io.pinecone.configs.PineconeConfig; import io.pinecone.configs.PineconeConnection; import io.pinecone.exceptions.PineconeValidationException; import io.pinecone.proto.*; +import io.pinecone.proto.DeleteRequest; +import io.pinecone.proto.DescribeIndexStatsRequest; +import io.pinecone.proto.FetchResponse; +import io.pinecone.proto.ListResponse; +import io.pinecone.proto.QueryRequest; +import io.pinecone.proto.QueryResponse; +import io.pinecone.proto.UpdateRequest; +import io.pinecone.proto.UpsertRequest; +import io.pinecone.proto.UpsertResponse; import io.pinecone.unsigned_indices_model.QueryResponseWithUnsignedIndices; import io.pinecone.unsigned_indices_model.VectorWithUnsignedIndices; +import okhttp3.OkHttpClient; +import org.openapitools.db_data.client.ApiClient; +import org.openapitools.db_data.client.ApiException; +import org.openapitools.db_data.client.Configuration; +import org.openapitools.db_data.client.api.BulkOperationsApi; +import org.openapitools.db_data.client.model.*; import java.util.List; +import static io.pinecone.clients.Pinecone.buildOkHttpClient; + /** * A client for interacting with a Pinecone index via GRPC asynchronously. Allows for upserting, querying, fetching, updating, and deleting vectors. @@ -38,6 +56,7 @@ public class AsyncIndex implements IndexInterface list(String namespace, String prefix, Stri return asyncStub.list(listRequest); } + /** + *

Initiates an asynchronous import of vectors from object storage into a specified index.

+ * + *

The method constructs a {@link StartImportRequest} using the provided URI for the data and optional + * storage integration ID. It also allows for specifying how to respond to errors during the import process + * through the {@link ImportErrorMode}. The import operation is then initiated via a call to the + * underlying {@link BulkOperationsApi}.

+ * + *

Example: + *

{@code
+     *     import org.openapitools.db_data.client.ApiException;
+     *     import org.openapitools.db_data.client.model.ImportErrorMode;
+     *
+     *     ...
+     *
+     *     String uri = "s3://path/to/file.parquet";
+     *     String integrationId = "123-456-789";
+     *     StartImportResponse response = asyncIndex.startImport(uri, integrationId, ImportErrorMode.OnErrorEnum.CONTINUE);
+     *  }
+ * + * @param uri The URI prefix under which the data to import is available. + * @param integrationId The ID of the storage integration to access the data. Can be null or empty. + * @param errorMode Indicates how to respond to errors during the import process. Can be null. + * @return {@link StartImportResponse} containing the details of the initiated import operation. + * @throws ApiException if there are issues processing the request or communicating with the server. + * This includes network issues, server errors, or serialization issues with the request or response. + */ + public StartImportResponse startImport(String uri, String integrationId, ImportErrorMode.OnErrorEnum errorMode) throws ApiException { + StartImportRequest importRequest = new StartImportRequest(); + importRequest.setUri(uri); + if(integrationId != null && !integrationId.isEmpty()) { + importRequest.setIntegrationId(integrationId); + } + if(errorMode != null) { + ImportErrorMode importErrorMode = new ImportErrorMode().onError(errorMode); + importRequest.setErrorMode(importErrorMode); + } + + return bulkOperations.startBulkImport(importRequest); + } + + /** + *

Lists all recent and ongoing import operations for the specified index with default limit and pagination.

+ * + *

The method constructs a request to fetch a list of import operations, limited by the default value set to 100 + * number of operations to return per page. The pagination token is set to null as well by default.

+ * + * + *

Example: + *

{@code
+     *     import org.openapitools.db_data.client.ApiException;
+     *     import org.openapitools.db_data.client.model.ListImportsResponse;
+     *
+     *     ...
+     *
+     *     ListImportsResponse response = asyncIndex.listImports();
+     *  }
+ * + * @return {@link ListImportsResponse} containing the list of recent and ongoing import operations. + * @throws ApiException if there are issues processing the request or communicating with the server. + * This includes network issues, server errors, or serialization issues with the request or response. + */ + public ListImportsResponse listImports() throws ApiException { + return listImports(100, null); + } + + /** + *

Lists all recent and ongoing import operations for the specified index based on limit.

+ * + *

The method constructs a request to fetch a list of import operations, limited by the specified + * maximum number of operations to return per page. The pagination token is set to null by default.

+ * + * + *

Example: + *

{@code
+     *     import org.openapitools.db_data.client.ApiException;
+     *     import org.openapitools.db_data.client.model.ListImportsResponse;
+     *
+     *     ...
+     *     int limit = 10;
+     *     ListImportsResponse response = asyncIndex.listImports(limit);
+     *  }
+ * + * @param limit The maximum number of operations to return per page. Default is 100. + * @return {@link ListImportsResponse} containing the list of recent and ongoing import operations. + * @throws ApiException if there are issues processing the request or communicating with the server. + * This includes network issues, server errors, or serialization issues with the request or response. + */ + public ListImportsResponse listImports(Integer limit) throws ApiException { + return listImports(limit, null); + } + + /** + *

Lists all recent and ongoing import operations for the specified index.

+ * + *

The method constructs a request to fetch a list of import operations, limited by the specified + * maximum number of operations to return per page. The pagination token allows for + * deterministic pagination through the list of import operations.

+ * + *

Example: + *

{@code
+     *     import org.openapitools.db_data.client.ApiException;
+     *     import org.openapitools.db_data.client.model.ListImportsResponse;
+     *
+     *     ...
+     *     int limit = 10;
+     *     String paginationToken = "some-pagination-token";
+     *     ListImportsResponse response = asyncIndex.listImports(limit, paginationToken);
+     *  }
+ * + * @param limit The maximum number of operations to return per page. Default is 100. + * @param paginationToken The token to continue a previous listing operation. Can be null or empty. + * @return {@link ListImportsResponse} containing the list of recent and ongoing import operations. + * @throws ApiException if there are issues processing the request or communicating with the server. + * This includes network issues, server errors, or serialization issues with the request or response. + */ + public ListImportsResponse listImports(Integer limit, String paginationToken) throws ApiException { + return bulkOperations.listBulkImports(limit, paginationToken); + } + + /** + *

Retrieves detailed information about a specific import operation using its unique identifier.

+ * + *

The method constructs a request to fetch details of the specified import operation by its ID, + * allowing users to monitor the status and results of the import process.

+ * + *

Example: + *

{@code
+     *     import org.openapitools.db_data.client.ApiException;
+     *     import org.openapitools.db_data.client.model.ImportModel;
+     *
+     *     ...
+     *
+     *     String importId = "1";
+     *     ImportModel importDetails = asyncIndex.describeImport(importId);
+     *  }
+ * + * @param id The unique identifier for the import operation. + * @return {@link ImportModel} containing details of the specified import operation. + * @throws ApiException if there are issues processing the request or communicating with the server. + * This includes network issues, server errors, or serialization issues with the request or response. + */ + public ImportModel describeImport(String id) throws ApiException { + return bulkOperations.describeBulkImport(id); + } + + /** + *

Attempts to cancel an ongoing import operation using its unique identifier.

+ * + *

The method issues a request to cancel the specified import operation if it has not yet finished. + * If the operation is already completed, the method has no effect.

+ * + *

Example: + *

{@code
+     *     import org.openapitools.db_data.client.ApiException;
+     *
+     *     ...
+     *     String importId = "2";
+     *     asyncIndex.cancelImport(importId);
+     *  }
+ * + * @param id The unique identifier for the import operation to cancel. + * @throws ApiException if there are issues processing the request or communicating with the server. + * This includes network issues, server errors, or serialization issues with the request or response. + */ + public void cancelImport(String id) throws ApiException { + bulkOperations.cancelBulkImport(id); + } + /** * {@inheritDoc} * Closes the current index connection gracefully, releasing any resources associated with it. This method should diff --git a/src/main/java/io/pinecone/clients/Pinecone.java b/src/main/java/io/pinecone/clients/Pinecone.java index 16d61a3e..d3b60c1b 100644 --- a/src/main/java/io/pinecone/clients/Pinecone.java +++ b/src/main/java/io/pinecone/clients/Pinecone.java @@ -873,7 +873,7 @@ public AsyncIndex getAsyncIndexConnection(String indexName) throws PineconeValid config.setHost(getIndexHost(indexName)); PineconeConnection connection = getConnection(indexName); - return new AsyncIndex(connection, indexName); + return new AsyncIndex(config, connection, indexName); } /** diff --git a/src/test/java/io/pinecone/clients/ImportsTest.java b/src/test/java/io/pinecone/clients/ImportsTest.java new file mode 100644 index 00000000..7779f713 --- /dev/null +++ b/src/test/java/io/pinecone/clients/ImportsTest.java @@ -0,0 +1,161 @@ +package io.pinecone.clients; + +import io.pinecone.configs.PineconeConfig; +import io.pinecone.configs.PineconeConnection; +import okhttp3.OkHttpClient; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.openapitools.db_data.client.ApiException; +import org.openapitools.db_data.client.api.BulkOperationsApi; +import org.openapitools.db_data.client.model.*; + +import java.time.OffsetDateTime; +import java.time.format.DateTimeFormatter; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +public class ImportsTest { + + private BulkOperationsApi bulkOperationsApiMock; + private AsyncIndex asyncIndex; + + @BeforeEach + public void setUp() { + PineconeConnection connectionMock = Mockito.mock(PineconeConnection.class); + PineconeConfig configMock = Mockito.mock(PineconeConfig.class); + + bulkOperationsApiMock = mock(BulkOperationsApi.class); + OkHttpClient httpClientMock = mock(OkHttpClient.class); + + when(configMock.getCustomOkHttpClient()).thenReturn(httpClientMock); + when(configMock.getApiKey()).thenReturn("fake-api-key"); + when(configMock.getUserAgent()).thenReturn("fake-user-agent"); + when(configMock.isTLSEnabled()).thenReturn(true); + when(configMock.getHost()).thenReturn("localhost"); + + asyncIndex = new AsyncIndex(configMock, connectionMock, "test-index"); + asyncIndex.bulkOperations = bulkOperationsApiMock; // Replace with mock + } + + @Test + public void testStartImportMinimal() throws ApiException { + StartImportResponse mockResponse = new StartImportResponse(); + mockResponse.setId("1"); + + when(bulkOperationsApiMock.startBulkImport(any(StartImportRequest.class))) + .thenReturn(mockResponse); + + StartImportResponse response = asyncIndex.startImport("s3://path/to/file.parquet", null, null); + + assertEquals("1", response.getId()); + } + + @Test + public void testStartImportWithIntegrationId() throws ApiException { + StartImportResponse mockResponse = new StartImportResponse(); + mockResponse.setId("1"); + + when(bulkOperationsApiMock.startBulkImport(any(StartImportRequest.class))) + .thenReturn(mockResponse); + + StartImportResponse response = asyncIndex.startImport("s3://path/to/file.parquet", "integration-123", null); + + assertEquals("1", response.getId()); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(StartImportRequest.class); + verify(bulkOperationsApiMock).startBulkImport(requestCaptor.capture()); + StartImportRequest capturedRequest = requestCaptor.getValue(); + + assertEquals("s3://path/to/file.parquet", capturedRequest.getUri()); + assertEquals("integration-123", capturedRequest.getIntegrationId()); + } + + @Test + public void testStartImportWithErrorMode() throws ApiException { + StartImportResponse mockResponse = new StartImportResponse(); + mockResponse.setId("1"); + + when(bulkOperationsApiMock.startBulkImport(any(StartImportRequest.class))) + .thenReturn(mockResponse); + + StartImportResponse response = asyncIndex.startImport("s3://path/to/file.parquet", null, ImportErrorMode.OnErrorEnum.CONTINUE); + + assertEquals("1", response.getId()); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(StartImportRequest.class); + verify(bulkOperationsApiMock).startBulkImport(requestCaptor.capture()); + StartImportRequest capturedRequest = requestCaptor.getValue(); + + assertEquals(ImportErrorMode.OnErrorEnum.CONTINUE, capturedRequest.getErrorMode().getOnError()); + } + + @Test + public void testStartImportWithInvalidUri() throws ApiException { + ApiException exception = new ApiException(400, "Invalid URI"); + when(bulkOperationsApiMock.startBulkImport(any(StartImportRequest.class))) + .thenThrow(exception); + + ApiException thrownException = assertThrows(ApiException.class, () -> { + asyncIndex.startImport("invalid-uri", null, null); + }); + + assertEquals(400, thrownException.getCode()); + assert(thrownException.getLocalizedMessage().contains("Invalid URI")); + } + + @Test + public void testDescribeImport() throws ApiException { + String uri = "s3://path/to/file.parquet"; + String errorMode = "CONTINUE"; + OffsetDateTime createdAt = OffsetDateTime.parse("2024-10-24T00:00:00Z", DateTimeFormatter.ISO_OFFSET_DATE_TIME); + OffsetDateTime finishedAt = OffsetDateTime.parse("2024-10-24T05:02:00Z", DateTimeFormatter.ISO_OFFSET_DATE_TIME); + float percentComplete = 43.2f; + + ImportModel mockResponse = new ImportModel(); + mockResponse.setId("1"); + mockResponse.setRecordsImported(1000L); + mockResponse.setUri(uri); + mockResponse.setStatus(ImportModel.StatusEnum.INPROGRESS); + mockResponse.setError(errorMode); + mockResponse.setCreatedAt(createdAt); + mockResponse.setFinishedAt(finishedAt); + mockResponse.setPercentComplete(43.2f); + + when(bulkOperationsApiMock.describeBulkImport("1")).thenReturn(mockResponse); + + ImportModel response = asyncIndex.describeImport("1"); + + assertEquals("1", response.getId()); + assertEquals(1000, response.getRecordsImported()); + assertEquals(uri, response.getUri()); + assertEquals("InProgress", response.getStatus().getValue()); + assertEquals(errorMode, response.getError()); + assertEquals(createdAt, response.getCreatedAt()); + assertEquals(finishedAt, response.getFinishedAt()); + assertEquals(percentComplete, response.getPercentComplete()); + + // Verify that the describeBulkImport method was called once + verify(bulkOperationsApiMock, times(1)).describeBulkImport("1"); + } + + @Test + void testListImports() throws ApiException { + ListImportsResponse mockResponse = new ListImportsResponse(); + mockResponse.setData(Collections.singletonList(new ImportModel())); + mockResponse.setPagination(new Pagination()); + + when(bulkOperationsApiMock.listBulkImports(anyInt(), anyString())).thenReturn(mockResponse); + + ListImportsResponse response = asyncIndex.listImports(10, "next-token"); + + assertNotNull(response); + assertEquals(1, response.getData().size()); + assertNotNull(response.getPagination()); + verify(bulkOperationsApiMock, times(1)) + .listBulkImports(10, "next-token"); + } +}