
Commit 930099f
Author: Kunjan Shah

feat(llm): add GroqClient support to set any Groq model as default LLMClient (#1977)

Parent: 6808a56

File tree: 4 files changed, +169 -19 lines

‎giskard/llm/__init__.py‎

Lines changed: 1 addition, 0 deletions

@@ -10,4 +10,5 @@
     "set_llm_api",
     "set_default_embedding",
     "set_embedding_model",
+    "GroqClient"
 ]
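Since this hunk only touches `__all__`, the package-root re-export works only if `GroqClient` is also imported somewhere in `giskard/llm/__init__.py`. A quick sanity check, assuming that import is in place (hypothetical, not shown in this diff):

    # Assumes GroqClient is actually imported in giskard/llm/__init__.py,
    # not just listed in __all__; otherwise this raises ImportError.
    from giskard.llm import GroqClient
    print(GroqClient)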

‎giskard/llm/client/__init__.py‎

Lines changed: 34 additions, 19 deletions

@@ -6,6 +6,10 @@
 
 from .base import ChatMessage, LLMClient
 from .logger import LLMLogger
+from .groq_client import GroqClient
+from ..config import LLMConfigurationError
+
+
 
 logger = logging.getLogger(__name__)
 
@@ -36,12 +40,15 @@ def get_default_llm_api() -> str:
     global _default_llm_api
     if _default_llm_api is None:
         _default_llm_api = os.getenv(
-            "GSK_LLM_API", "azure" if "AZURE_OPENAI_API_KEY" in os.environ else "openai"
+            "GSK_LLM_API",
+            "azure" if "AZURE_OPENAI_API_KEY" in os.environ
+            else "groq" if "GROQ_API_KEY" in os.environ
+            else "openai"
         ).lower()
 
-        if _default_llm_api not in {"azure", "openai"}:
+        if _default_llm_api not in {"azure", "openai", "groq"}:
             logging.warning(
-                f"LLM-based evaluation is only working with `azure` and `openai`. Found {_default_llm_api} in GSK_LLM_API, falling back to `openai`"
+                f"LLM-based evaluation only works with `azure`, `openai` and `groq`. Found {_default_llm_api} in GSK_LLM_API, falling back to `openai`"
             )
             _default_llm_api = "openai"
 
@@ -89,27 +96,34 @@ def get_default_client() -> LLMClient:
     global _default_llm_api
     global _default_llm_model
     global _disable_structured_output
+
 
     if _default_client is not None:
         return _default_client
 
     try:
-        from .litellm import LiteLLMClient
-
-        if (
-            _default_llm_api is not None
-            and "/" in _default_llm_model
-            and not _default_llm_model.startswith(f"{_default_llm_api}/")
-        ):
-            raise ValueError(
-                f"Model {_default_llm_model} is not compatible with {_default_llm_api}: https://docs.giskard.ai/en/latest/open_source/setting_up/index.html"
-            )
-        if _default_llm_api is not None and "/" not in _default_llm_model:
-            _default_llm_model = f"{_default_llm_api}/{_default_llm_model}"
-
-        _default_client = LiteLLMClient(_default_llm_model, _disable_structured_output, _default_completion_params)
-    except ImportError:
-        raise ValueError(f"LLM scan using {_default_llm_model} requires litellm")
+        if _default_llm_api == "groq":
+            groq_api_key = os.getenv("GROQ_API_KEY")
+            if not groq_api_key:
+                raise LLMConfigurationError("GROQ_API_KEY environment variable is not set")
+            _default_client = GroqClient(_default_llm_model)
+        else:
+            from .litellm import LiteLLMClient
+
+            if (
+                _default_llm_api is not None
+                and "/" in _default_llm_model
+                and not _default_llm_model.startswith(f"{_default_llm_api}/")
+            ):
+                raise ValueError(
+                    f"Model {_default_llm_model} is not compatible with {_default_llm_api}: https://docs.giskard.ai/en/latest/open_source/setting_up/index.html"
+                )
+            if _default_llm_api is not None and "/" not in _default_llm_model:
+                _default_llm_model = f"{_default_llm_api}/{_default_llm_model}"
+
+            _default_client = LiteLLMClient(_default_llm_model, _disable_structured_output, _default_completion_params)
+    except ImportError as e:
+        raise ValueError(f"LLM scan using {_default_llm_model} requires the appropriate client library") from e
 
     return _default_client
 
@@ -122,4 +136,5 @@ def get_default_client() -> LLMClient:
     "set_llm_model",
     "get_default_llm_api",
     "set_llm_api",
+    "GroqClient",
 ]
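Taken together, these hunks make the default-API resolution prefer `azure`, then `groq`, then `openai`, depending on which API key is present. A minimal sketch of the intended behavior, assuming `giskard` is installed, `GSK_LLM_API` is unset, and the key value below is a placeholder:

    import os

    # Detection only checks that the variable exists; the value is a placeholder.
    os.environ["GROQ_API_KEY"] = "gsk_placeholder"
    os.environ.pop("AZURE_OPENAI_API_KEY", None)  # Azure would take precedence

    from giskard.llm.client import get_default_llm_api

    assert get_default_llm_api() == "groq"

`get_default_client()` then instantiates a `GroqClient` with the configured default model, raising `LLMConfigurationError` when `GROQ_API_KEY` is missing.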

‎giskard/llm/client/groq_client.py‎

Lines changed: 93 additions, 0 deletions (new file)

@@ -0,0 +1,93 @@
+from typing import Optional, Sequence
+
+from dataclasses import asdict
+from logging import warning
+import logging
+
+from ..config import LLMConfigurationError
+from ..errors import LLMImportError
+from . import LLMClient
+from .base import ChatMessage
+
+try:
+    from groq import Groq
+    import groq
+except ImportError as err:
+    raise LLMImportError(flavor="llm") from err
+
+AUTH_ERROR_MESSAGE = (
+    "Could not authenticate with Groq API. Please make sure you have configured the API key by "
+    "setting GROQ_API_KEY in the environment."
+)
+
+def _supports_json_format(model: str) -> bool:
+    if "llama-3.3-70b-versatile" in model:
+        return True
+
+    if model in ("llama-3.1-8b-instant", "gemma2-9b-it"):
+        return True
+
+    return False
+
+logger = logging.getLogger(__name__)
+
+class GroqClient(LLMClient):
+    def __init__(
+        self,
+        model: str = "llama-3.3-70b-versatile",  # default model for Groq
+        client: Optional[Groq] = None,
+        json_mode: Optional[bool] = None,
+    ):
+        logger.info(f"Initializing GroqClient with model: {model}")
+        self.model = model
+        self._client = client or Groq()
+        self.json_mode = json_mode if json_mode is not None else _supports_json_format(model)
+        logger.info("GroqClient initialized successfully")
+
+    def complete(
+        self,
+        messages: Sequence[ChatMessage],
+        temperature: float = 1.0,
+        max_tokens: Optional[int] = None,
+        caller_id: Optional[str] = None,
+        seed: Optional[int] = None,
+        format=None,
+    ) -> ChatMessage:
+        logger.info(f"GroqClient.complete called with model: {self.model}")
+        logger.info(f"Messages: {messages}")
+
+        extra_params = dict()
+
+        if seed is not None:
+            extra_params["seed"] = seed
+
+        if self.json_mode:
+            if format not in (None, "json", "json_object"):
+                warning(f"Unsupported format '{format}', ignoring.")
+                format = None
+
+            if format in ("json", "json_object"):
+                extra_params["response_format"] = {"type": "json_object"}
+
+        try:
+            completion = self._client.chat.completions.create(
+                model=self.model,
+                messages=[asdict(m) for m in messages],
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **extra_params,
+            )
+        except groq.AuthenticationError as err:
+            raise LLMConfigurationError(AUTH_ERROR_MESSAGE) from err
+
+        self.logger.log_call(
+            prompt_tokens=completion.usage.prompt_tokens,
+            sampled_tokens=completion.usage.completion_tokens,
+            model=self.model,
+            client_class=self.__class__.__name__,
+            caller_id=caller_id,
+        )
+
+        msg = completion.choices[0].message
+
+        return ChatMessage(role=msg.role, content=msg.content)
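A minimal usage sketch for the new client, assuming the `groq` package is installed and `GROQ_API_KEY` is exported; the prompt and sampling parameters are illustrative:

    from giskard.llm.client.groq_client import GroqClient
    from giskard.llm.client.base import ChatMessage

    client = GroqClient(model="llama-3.3-70b-versatile")
    res = client.complete(
        [ChatMessage(role="user", content="Reply with one word: hello")],
        temperature=0.2,
        max_tokens=32,
    )
    print(res.role, res.content)

Since `json_mode` defaults to `_supports_json_format(model)`, passing `format="json"` only sets `response_format={"type": "json_object"}` for models on that allowlist.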

‎tests/llm/test_llm_client.py‎

Lines changed: 41 additions, 0 deletions

@@ -12,6 +12,7 @@
 from giskard.llm.client.bedrock import ClaudeBedrockClient
 from giskard.llm.client.gemini import GeminiClient
 from giskard.llm.client.openai import OpenAIClient
+from giskard.llm.client.groq_client import GroqClient
 
 PYDANTIC_V2 = pydantic.__version__.startswith("2.")
 
@@ -217,3 +218,43 @@ def test_gemini_client():
     # Assert that the response is a ChatMessage and has the correct content
     assert isinstance(res, ChatMessage)
     assert res.content == "This is a test!"
+
+@pytest.mark.skipif(not PYDANTIC_V2, reason="Groq client test requires Pydantic v2")
+def test_groq_client():
+    # Mock the Groq response
+    demo_response = Mock()
+    demo_response.usage = Mock(prompt_tokens=13, completion_tokens=7)
+    demo_response.choices = [
+        Mock(
+            message=Mock(
+                role="assistant",
+                content="This is a test!",
+            )
+        )
+    ]
+
+    # Mock the Groq client
+    mock_client = Mock()
+    mock_client.chat.completions.create.return_value = demo_response
+
+    client = GroqClient(model="llama3-8b-8192", client=mock_client)
+
+    # Call the complete method
+    res = client.complete(
+        [ChatMessage(role="user", content="Hello")],
+        temperature=0.7,
+        max_tokens=100,
+    )
+
+    # Assert that create was called once with the correct arguments
+    mock_client.chat.completions.create.assert_called_once()
+    assert mock_client.chat.completions.create.call_args[1]["messages"] == [
+        {"role": "user", "content": "Hello"}
+    ]
+    assert mock_client.chat.completions.create.call_args[1]["temperature"] == 0.7
+    assert mock_client.chat.completions.create.call_args[1]["max_tokens"] == 100
+    assert mock_client.chat.completions.create.call_args[1]["model"] == "llama3-8b-8192"
+
+    # Assert that the response is a ChatMessage and has the correct content
+    assert isinstance(res, ChatMessage)
+    assert res.content == "This is a test!"
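To run just the new test locally (standard pytest node-id selection; requires the dev dependencies and Pydantic v2):

    pytest tests/llm/test_llm_client.py::test_groq_client -v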
