Skip to content

Commit ef817bf

Browse files
author
Kunjan Shah
committed
Fix: code formatting via pre-commit
1 parent 4f27293 commit ef817bf

File tree

3 files changed

+26
-44
lines changed

3 files changed

+26
-44
lines changed

‎giskard/llm/client/__init__.py‎

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,8 @@ def get_default_llm_api() -> str:
3636
global _default_llm_api
3737
if _default_llm_api is None:
3838
_default_llm_api = os.getenv(
39-
"GSK_LLM_API",
40-
"azure" if "AZURE_OPENAI_API_KEY" in os.environ
41-
else "groq" if "GROQ_API_KEY" in os.environ
42-
else "openai"
39+
"GSK_LLM_API",
40+
"azure" if "AZURE_OPENAI_API_KEY" in os.environ else "groq" if "GROQ_API_KEY" in os.environ else "openai",
4341
).lower()
4442

4543
if _default_llm_api not in {"azure", "openai", "groq"}:
@@ -92,7 +90,6 @@ def get_default_client() -> LLMClient:
9290
global _default_llm_api
9391
global _default_llm_model
9492
global _disable_structured_output
95-
9693

9794
try:
9895
from .litellm import LiteLLMClient

‎giskard/llm/client/groq_client.py‎

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from typing import Optional, Sequence
22

3-
from dataclasses import asdict
43
import logging
4+
from dataclasses import asdict
55

66
from ..config import LLMConfigurationError
77
from ..errors import LLMImportError
88
from . import LLMClient
99
from .base import ChatMessage
1010

1111
try:
12-
from groq import Groq
1312
import groq
13+
from groq import Groq
1414
except ImportError as err:
1515
raise LLMImportError(flavor="llm") from err
1616

@@ -29,24 +29,22 @@
2929

3030
logger = logging.getLogger(__name__)
3131

32+
3233
class GroqClient(LLMClient):
3334
def __init__(
34-
self,
35+
self,
3536
model: str = "llama-3.3-70b-versatile", # Default model for Groq
3637
client: Groq = None,
37-
#json_mode: Optional[bool] = None
38+
# json_mode: Optional[bool] = None
3839
):
3940
logger.info(f"Initializing GroqClient with model: {model}")
4041
self.model = model
4142
self._client = client or Groq()
42-
logger.info("GroqClient initialized successfully")
43-
43+
logger.info("GroqClient initialized successfully")
44+
4445
def get_config(self) -> dict:
4546
"""Return the configuration of the LLM client."""
46-
return {
47-
"client_type": self.__class__.__name__,
48-
"model": self.model
49-
}
47+
return {"client_type": self.__class__.__name__, "model": self.model}
5048

5149
def complete(
5250
self,
@@ -59,12 +57,12 @@ def complete(
5957
) -> ChatMessage:
6058
logger.info(f"GroqClient.complete called with model: {self.model}")
6159
logger.info(f"Messages: {messages}")
62-
60+
6361
extra_params = dict()
6462

65-
extra_params["seed"] = seed
63+
extra_params["seed"] = seed
6664

67-
if format in {"json", "json_object"}:
65+
if format in {"json", "json_object"}:
6866
extra_params["response_format"] = {"type": "json_object"}
6967

7068
try:
@@ -75,16 +73,16 @@ def complete(
7573
max_tokens=max_tokens,
7674
**extra_params,
7775
)
78-
79-
except groq.AuthenticationError as err:
76+
77+
except groq.AuthenticationError as err:
8078
raise LLMConfigurationError(AUTH_ERROR_MESSAGE) from err
81-
79+
8280
except groq.BadRequestError as err:
8381
if format in {"json", "json_object"}:
8482
raise LLMConfigurationError(
8583
f"Model '{self.model}' does not support JSON output or the request format is incorrect.\n\n{JSON_MODE_GUIDANCE}"
8684
) from err
87-
raise
85+
raise
8886

8987
self.logger.log_call(
9088
prompt_tokens=completion.usage.prompt_tokens,

‎tests/llm/test_llm_client.py‎

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from importlib.util import find_spec
23
from unittest.mock import MagicMock, Mock, patch
34

45
import pydantic
@@ -11,9 +12,8 @@
1112
from giskard.llm.client import ChatMessage
1213
from giskard.llm.client.bedrock import ClaudeBedrockClient
1314
from giskard.llm.client.gemini import GeminiClient
14-
from giskard.llm.client.openai import OpenAIClient
1515
from giskard.llm.client.groq_client import GroqClient
16-
from importlib.util import find_spec
16+
from giskard.llm.client.openai import OpenAIClient
1717

1818
PYDANTIC_V2 = pydantic.__version__.startswith("2.")
1919

@@ -241,46 +241,33 @@ def test_gemini_client():
241241
assert isinstance(res, ChatMessage)
242242
assert res.content == "This is a test!"
243243

244+
244245
# Check if groq is installed
245246
has_groq = find_spec("groq") is not None
246247

247-
@pytest.mark.skipif(not PYDANTIC_V2 or not has_groq,
248-
reason="Groq client test requires Pydantic v2 and groq package")
248+
249+
@pytest.mark.skipif(not PYDANTIC_V2 or not has_groq, reason="Groq client test requires Pydantic v2 and groq package")
249250
def test_groq_client():
250251
# Mock the Groq response
251252
demo_response = Mock()
252253
demo_response.usage = Mock(prompt_tokens=13, completion_tokens=7)
253-
demo_response.choices = [
254-
Mock(
255-
message=Mock(
256-
role="assistant",
257-
content="This is a test!"
258-
)
259-
)
260-
]
254+
demo_response.choices = [Mock(message=Mock(role="assistant", content="This is a test!"))]
261255

262256
# Mock the Groq client
263257
mock_client = Mock()
264258
mock_client.chat.completions.create.return_value = demo_response
265-
259+
266260
client = GroqClient(model="llama-3.3-70b-versatile", client=mock_client)
267261

268262
# Call the complete method
269-
res = client.complete(
270-
[ChatMessage(role="user", content="Hello")],
271-
temperature=0.7,
272-
format="json",
273-
max_tokens=100
274-
)
263+
res = client.complete([ChatMessage(role="user", content="Hello")], temperature=0.7, format="json", max_tokens=100)
275264

276265
# Assert json_object format was passed
277266
assert mock_client.chat.completions.create.call_args[1]["response_format"] == {"type": "json_object"}
278267

279268
# Assert that create was called with correct arguments
280269
mock_client.chat.completions.create.assert_called_once()
281-
assert mock_client.chat.completions.create.call_args[1]["messages"] == [
282-
{"role": "user", "content": "Hello"}
283-
]
270+
assert mock_client.chat.completions.create.call_args[1]["messages"] == [{"role": "user", "content": "Hello"}]
284271
assert mock_client.chat.completions.create.call_args[1]["temperature"] == 0.7
285272
assert mock_client.chat.completions.create.call_args[1]["max_tokens"] == 100
286273
assert mock_client.chat.completions.create.call_args[1]["model"] == "llama-3.3-70b-versatile"

0 commit comments

Comments
 (0)