Skip to content

Commit

Permalink
Merge branch 'main' into hide-collection-types
Browse files Browse the repository at this point in the history
  • Loading branch information
dmytrostruk authored Oct 13, 2023
2 parents d8fee5d + 988f65e commit d4024ab
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class HuggingFaceTextCompletion(TextCompletionClientBase):
def __init__(
self,
model_id: str,
device: Optional[int] = -1,
device: Optional[int] = None,
task: Optional[str] = None,
log: Optional[Logger] = None,
model_kwargs: Dict[str, Any] = None,
Expand All @@ -35,7 +35,10 @@ def __init__(
Arguments:
model_id {str} -- Hugging Face model card string, see
https://huggingface.co/models
device {Optional[int]} -- Device to run the model on, -1 for CPU, 0+ for GPU.
device {Optional[int]} -- Device to run the model on: 0+ selects a GPU; None (the default)
    runs on CPU, or defers to device_map if one is provided. (If both device and
    device_map are specified, device overrides device_map, which can lead to
    unexpected behavior.)
task {Optional[str]} -- Model completion task type, options are:
- summarization: takes a long text and returns a shorter summary.
- text-generation: takes incomplete text and returns a set of completion candidates.
Expand Down Expand Up @@ -64,11 +67,15 @@ def __init__(
"Please ensure that torch and transformers are installed to use HuggingFaceTextCompletion"
)

self.device = (
"cuda:" + str(device)
if device >= 0 and torch.cuda.is_available()
else "cpu"
)
device_map = self._pipeline_kwargs.get("device_map", None)
if device is None:
self.device = "cpu" if device_map is None else None
else:
self.device = (
"cuda:" + str(device)
if device >= 0 and torch.cuda.is_available()
else "cpu"
)

self.generator = transformers.pipeline(
task=self._task,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self._org_id = org_id
self._api_type = api_type
self._api_version = api_version
self._endpoint = endpoint
self._endpoint = endpoint.rstrip("/") if endpoint is not None else None
self._log = log if log is not None else NullLogger()
self._messages = []

Expand Down Expand Up @@ -271,7 +271,7 @@ async def _send_chat_request(
except Exception as ex:
raise AIException(
AIException.ErrorCodes.ServiceError,
"OpenAI service failed to complete the chat",
f"{self.__class__.__name__} failed to complete the chat",
ex,
) from ex

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
self._api_key = api_key
self._api_type = api_type
self._api_version = api_version
self._endpoint = endpoint
self._endpoint = endpoint.rstrip("/") if endpoint is not None else None
self._org_id = org_id
self._log = log if log is not None else NullLogger()

Expand Down Expand Up @@ -161,7 +161,7 @@ async def _send_completion_request(
except Exception as ex:
raise AIException(
AIException.ErrorCodes.ServiceError,
"OpenAI service failed to complete the prompt",
f"{self.__class__.__name__} failed to complete the prompt",
ex,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self._api_key = api_key
self._api_type = api_type
self._api_version = api_version
self._endpoint = endpoint
self._endpoint = endpoint.rstrip("/") if endpoint is not None else None
self._org_id = org_id
self._log = log if log is not None else NullLogger()

Expand Down Expand Up @@ -81,6 +81,6 @@ async def generate_embeddings_async(
except Exception as ex:
raise AIException(
AIException.ErrorCodes.ServiceError,
"OpenAI service failed to generate embeddings",
f"{self.__class__.__name__} failed to generate embeddings",
ex,
)

0 comments on commit d4024ab

Please sign in to comment.