diff --git a/src/curate_gpt/extract/openai_extractor.py b/src/curate_gpt/extract/openai_extractor.py index 5f18dac..878ec82 100644 --- a/src/curate_gpt/extract/openai_extractor.py +++ b/src/curate_gpt/extract/openai_extractor.py @@ -2,12 +2,11 @@ import json import logging +import os from dataclasses import dataclass from typing import List -from openai import OpenAI - -client = OpenAI() +from openai import OpenAI, OpenAIError from curate_gpt.extract.extractor import AnnotatedObject, Extractor @@ -27,6 +26,18 @@ class OpenAIExtractor(Extractor): # conversation: List[Dict[str, Any]] = None # conversation_mode: bool = False + @staticmethod + def _get_openai_client(): + """ + Private method to get an instance of the OpenAI client. + """ + api_key = os.getenv("OPENAI_API_KEY") + if api_key is None: + raise OpenAIError( + "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" + ) + return OpenAI(api_key=api_key) + def functions(self): return [ { @@ -91,6 +102,8 @@ def extract( } ) # print(yaml.dump(messages)) + client = self._get_openai_client() + response = client.chat.completions.create( model=self.model, functions=self.functions(),