Skip to content

Commit d4ca01b

Browse files
DARREN OBERST
authored and committed
adding support for custom openai client for Azure use
1 parent bca2297 commit d4ca01b

File tree

3 files changed

+147
-3
lines changed

3 files changed

+147
-3
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
2+
""" This example shows how to use OpenAIConfigs to create a configured OpenAI client, most often used for
3+
Azure OpenAI access."""
4+
5+
import os
6+
7+
from llmware.models import ModelCatalog
8+
from llmware.configs import OpenAIConfig
9+
from openai import AzureOpenAI, OpenAI
10+
11+
12+
# to start - OpenAI client is created in OpenAI Generative and Embedding models classes at the time of inference
13+
# the client will be created as a standard OpenAI client with the api_keys passed
14+
15+
my_azure_client = OpenAIConfig().get_azure_client()
16+
print("my azure client to start: ", my_azure_client)
17+
18+
# to configure an AzureOpenAI client, two steps:
19+
# first, create the client with openai >= 1.0 python SDK, (see above) e.g.:
20+
21+
client = AzureOpenAI(
22+
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT"),
23+
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
24+
api_version="2024-02-01"
25+
)
26+
27+
# second, set the azure client in OpenAIConfigs as below:
28+
OpenAIConfig().set_azure_client(client)
29+
print("my azure client - set: ", OpenAIConfig().get_azure_client())
30+
31+
# now, run the inference like any other in llmware
32+
33+
# OpenAI Generative call
34+
model = ModelCatalog().load_model("gpt-4")
35+
36+
# the model will check the value of get_azure_client() in the configs -> if set, then will use
37+
response = model.inference("What is the future of AI")
38+
print("response: ", response)
39+
40+
# OpenAI Embedding call
41+
model = ModelCatalog().load_model("text-embedding-3-small")
42+
embedding = model.embedding(["This is a sample sentence for an embedding test."])
43+
print("embedding: ", embedding)
44+
45+
# reset so you can use the standard OpenAI client
46+
OpenAIConfig().set_azure_client(None)
47+
48+
model = ModelCatalog().load_model("text-embedding-3-small", api_key="your openai api key")
49+
embedding = model.embedding(["This is a sample sentence for an embedding test."])
50+
print("embedding: ", embedding)
51+

‎llmware/configs.py‎

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,3 +896,58 @@ def get_auth_credentials(cls):
896896
@classmethod
897897
def get_auth_token_transport_header(cls):
898898
return cls._conf["auth_token_transport_header"]
899+
900+
901+
class OpenAIConfig:

    """Configuration state for OpenAI access - used mainly to register Azure OpenAI credentials.

    The typical use is to build an AzureOpenAI client up front and store it here, so that
    it is used in place of the standard OpenAI client.

    Within LLMWare, the OpenAI model classes consult this config before constructing a
    new OpenAI client - when _conf["openai_client"] already holds a client, that client
    is used for the inference/embedding process.

    For example:

        # create your AzureOpenAI client

        from openai import AzureOpenAI

        client = AzureOpenAI(
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version="2024-02-01")

        # add that client to the OpenAIConfig:

        OpenAIConfig().set_azure_client(client)
    """

    # class-level store shared across all instances - acts as a lightweight global config
    _conf = {"openai_client": None,
             "api_key": None,                # placeholder / not used currently
             "api_version": None,            # placeholder / not used currently
             "use_azure_endpoint": False}    # placeholder / not used currently

    @classmethod
    def get_config(cls, name):
        """Return the config value for *name*; raises ConfigKeyException if unknown."""
        if name not in cls._conf:
            raise ConfigKeyException(name)
        return cls._conf[name]

    @classmethod
    def set_config(cls, name, value):
        """Set (or add) the config entry *name* to *value*."""
        cls._conf[name] = value

    @classmethod
    def set_azure_client(cls, azure_client):
        """Register a pre-built (Azure) OpenAI client; pass None to clear it."""
        cls._conf["openai_client"] = azure_client

    @classmethod
    def get_azure_client(cls):
        """Return the registered client, or None if no custom client has been set."""
        return cls._conf["openai_client"]

‎llmware/models.py‎

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2564,6 +2564,9 @@ def inference(self, prompt, add_context=None, add_prompt_engineering=None, infer
25642564
if "max_tokens" in inference_dict:
25652565
self.target_requested_output_tokens = inference_dict["max_tokens"]
25662566

2567+
if "openai_client" in inference_dict:
2568+
self.openai_client = inference_dict["openai_client"]
2569+
25672570
# api_key
25682571
if api_key:
25692572
self.api_key = api_key
@@ -2583,6 +2586,8 @@ def inference(self, prompt, add_context=None, add_prompt_engineering=None, infer
25832586
except ImportError:
25842587
raise DependencyNotInstalledException("openai >= 1.0")
25852588

2589+
from llmware.configs import OpenAIConfig
2590+
25862591
usage = {}
25872592
time_start = time.time()
25882593

@@ -2595,7 +2600,17 @@ def inference(self, prompt, add_context=None, add_prompt_engineering=None, infer
25952600

25962601
# updated OpenAI client to >v1.0 API - create client, and returns pydantic objects
25972602

2598-
client = OpenAI(api_key=self.api_key)
2603+
azure_client = OpenAIConfig().get_azure_client()
2604+
2605+
if not azure_client:
2606+
client = OpenAI(api_key=self.api_key)
2607+
2608+
else:
2609+
2610+
logging.info("update: applying custom OpenAI client from OpenAIConfig")
2611+
2612+
client = azure_client
2613+
25992614
response = client.chat.completions.create(model=self.model_name,messages=messages,
26002615
max_tokens=self.target_requested_output_tokens)
26012616

@@ -2616,7 +2631,17 @@ def inference(self, prompt, add_context=None, add_prompt_engineering=None, infer
26162631

26172632
text_prompt = prompt_final + self.separator
26182633

2619-
client = OpenAI(api_key=self.api_key)
2634+
azure_client = OpenAIConfig().get_azure_client()
2635+
2636+
if not azure_client:
2637+
client = OpenAI(api_key=self.api_key)
2638+
2639+
else:
2640+
2641+
logging.info("update: applying custom OpenAI client from OpenAIConfig")
2642+
2643+
client = azure_client
2644+
26202645
response = client.completions.create(model=self.model_name, prompt=text_prompt,
26212646
temperature=self.temperature,
26222647
max_tokens=self.target_requested_output_tokens)
@@ -3697,6 +3722,8 @@ def embedding(self, text_sample, api_key=None):
36973722
except ImportError:
36983723
raise DependencyNotInstalledException("openai >= 1.0")
36993724

3725+
from llmware.configs import OpenAIConfig
3726+
37003727
# insert safety check here
37013728
safe_samples = []
37023729
safety_buffer = 200
@@ -3732,7 +3759,18 @@ def embedding(self, text_sample, api_key=None):
37323759
# end - safety check
37333760

37343761
# update to open >v1.0 api - create client and output is pydantic objects
3735-
client = OpenAI(api_key=self.api_key)
3762+
3763+
azure_client = OpenAIConfig().get_azure_client()
3764+
3765+
if not azure_client:
3766+
client = OpenAI(api_key=self.api_key)
3767+
3768+
else:
3769+
3770+
logging.info("update: applying custom OpenAI client from OpenAIConfig")
3771+
3772+
client = azure_client
3773+
37363774
response = client.embeddings.create(model=self.model_name, input=text_prompt)
37373775

37383776
# logging.info("update: response: %s ", response)

0 commit comments

Comments
 (0)