Skip to content
7 changes: 5 additions & 2 deletions giskard/ml_worker/websocket/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,17 +520,20 @@ def run_test_suite(
client: Optional[GiskardClient], params: websocket.TestSuiteParam, *args, **kwargs
) -> websocket.TestSuite:
log_listener = LogListener()

loaded_artifacts = dict()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it initialized here if it's also set in the function with the 'if none' condition ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it' done to share the cache between the global arguments and test ones


try:
tests = [
{
"test": GiskardTest.download(t.testUuid, client, None),
"arguments": parse_function_arguments(client, t.arguments),
"arguments": parse_function_arguments(client, t.arguments, loaded_artifacts),
"id": t.id,
}
for t in params.tests
]

global_arguments = parse_function_arguments(client, params.globalArguments)
global_arguments = parse_function_arguments(client, params.globalArguments, loaded_artifacts)

datasets = {arg.original_id: arg for arg in global_arguments.values() if isinstance(arg, Dataset)}
for test in tests:
Expand Down
41 changes: 34 additions & 7 deletions giskard/ml_worker/websocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import os
import shutil
import uuid
from typing import Any, Dict, List, Optional

import pandas as pd
from mlflow.store.artifact.artifact_repo import verify_artifact_path
from typing import Any, Dict, List, Optional, Callable

from giskard.client.giskard_client import GiskardClient
from giskard.core.suite import DatasetInput, ModelInput, SuiteInput
Expand Down Expand Up @@ -158,7 +158,24 @@ def map_dataset_process_function_meta_ws(callable_type):
}


def parse_function_arguments(client: Optional[GiskardClient], request_arguments: List[websocket.FuncArgument]):
def _get_or_load(loaded_artifacts: Dict[str, Dict[str, Any]], type: str, uuid: str, load_fn: Callable[[], Any]) -> Any:
if type not in loaded_artifacts:
loaded_artifacts[type] = dict()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be avoided by giving a defaultdict instead of a dict, with dict as default factory


if uuid not in loaded_artifacts[type]:
loaded_artifacts[type][uuid] = load_fn()

return loaded_artifacts[type][uuid]


def parse_function_arguments(
client: Optional[GiskardClient],
request_arguments: List[websocket.FuncArgument],
loaded_artifacts: Optional[Dict[str, Dict[str, Any]]] = None,
):
if loaded_artifacts is None:
loaded_artifacts = dict()

arguments = dict()

# Processing empty list
Expand All @@ -169,14 +186,24 @@ def parse_function_arguments(client: Optional[GiskardClient], request_arguments:
if arg.is_none:
continue
if arg.dataset is not None:
arguments[arg.name] = Dataset.download(
client,
arg.dataset.project_key,
arguments[arg.name] = _get_or_load(
loaded_artifacts,
"Dataset",
arg.dataset.id,
arg.dataset.sample,
lambda: Dataset.download(
client,
arg.dataset.project_key,
arg.dataset.id,
arg.dataset.sample,
),
)
elif arg.model is not None:
arguments[arg.name] = BaseModel.download(client, arg.model.project_key, arg.model.id)
arguments[arg.name] = _get_or_load(
loaded_artifacts,
"BaseModel",
arg.model.id,
lambda: BaseModel.download(client, arg.model.project_key, arg.model.id),
)
elif arg.slicingFunction is not None:
arguments[arg.name] = SlicingFunction.download(
arg.slicingFunction.id, client, arg.slicingFunction.project_key
Expand Down
7 changes: 6 additions & 1 deletion giskard/models/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,12 @@ def __init__(
if len(classification_labels) != len(set(classification_labels)):
raise ValueError("Duplicates are found in 'classification_labels', please only provide unique values.")

self._cache = ModelCache(model_type, str(self.id), cache_dir=kwargs.get("prediction_cache_dir"))
self._cache = ModelCache(
model_type,
str(self.id),
persist_cache=kwargs.get("persist_cache", False),
cache_dir=kwargs.get("prediction_cache_dir"),
)

# sklearn and catboost will fill classification_labels before this check
if model_type == SupportedModelTypes.CLASSIFICATION and not classification_labels:
Expand Down
19 changes: 14 additions & 5 deletions giskard/models/cache/cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import csv
from pathlib import Path
from typing import Any, Iterable, List, Optional

import numpy as np
import pandas as pd
from typing import Any, Iterable, List, Optional

from ...client.python_utils import warning
from ...core.core import SupportedModelTypes
Expand All @@ -26,14 +26,23 @@ def flatten(xs):
class ModelCache:
_default_cache_dir_prefix = Path(settings.home_dir / settings.cache_dir / "global" / "prediction_cache")

def __init__(self, model_type: SupportedModelTypes, id: Optional[str] = None, cache_dir: Optional[Path] = None):
def __init__(
self,
model_type: SupportedModelTypes,
id: Optional[str] = None,
persist_cache: bool = False,
cache_dir: Optional[Path] = None,
):
self.id = id
self.prediction_cache = dict()

if cache_dir is None and self.id:
cache_dir = self._default_cache_dir_prefix.joinpath(self.id)
if persist_cache:
if cache_dir is None and self.id:
cache_dir = self._default_cache_dir_prefix.joinpath(self.id)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be self._default_cache_dir_prefix / self.id


self.cache_file = cache_dir / CACHE_CSV_FILENAME if cache_dir else None
self.cache_file = cache_dir / CACHE_CSV_FILENAME if cache_dir else None
else:
self.cache_file = None

self.vectorized_get_cache_or_na = np.vectorize(self.get_cache_or_na, otypes=[object])
self.model_type = model_type
Expand Down
5 changes: 4 additions & 1 deletion tests/models/test_model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import pandas as pd
import pytest
import xxhash
from langchain import LLMChain, PromptTemplate
from langchain.llms.fake import FakeListLLM

import giskard
from giskard import Dataset, Model
from giskard.core.core import SupportedModelTypes
from giskard.models.cache import ModelCache
from langchain import LLMChain, PromptTemplate


# https://symbl.cc/fr/unicode/blocks/

Expand All @@ -31,6 +32,7 @@ def test_unicode_prediction(keys, values):
with TemporaryDirectory() as temp_cache_dir:
cache = ModelCache(
model_type=SupportedModelTypes.TEXT_GENERATION,
persist_cache=True,
cache_dir=Path(temp_cache_dir),
)
key_series = pd.Series(keys)
Expand All @@ -43,6 +45,7 @@ def test_unicode_prediction(keys, values):
warmed_up_cache = ModelCache(
id="warmed_up",
model_type=SupportedModelTypes.TEXT_GENERATION,
persist_cache=True,
cache_dir=Path(temp_cache_dir),
)
# Ensure warm up works fine
Expand Down