Skip to content

Commit fa9d286

Browse files
author
kaiprvn
committed
Fix: resolve scipy.stats imports for compatibility
1 parent 752a81f commit fa9d286

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

‎giskard/testing/tests/drift.py‎

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99

1010
import numpy as np
1111
import pandas as pd
12-
from scipy.stats import chi2, ks_2samp
13-
from scipy.stats.stats import Ks_2sampResult, wasserstein_distance
12+
from scipy.stats import chi2, ks_2samp, wasserstein_distance
1413

1514
from giskard.core.test_result import TestMessage, TestMessageLevel, TestResult
1615
from giskard.datasets.base import Dataset
@@ -99,9 +98,10 @@ def _calculate_drift_psi(actual_series, reference_series, max_categories):
9998
return total_psi, pd.DataFrame(output_data)
10099

101100

102-
def _calculate_ks(actual_series, reference_series) -> Ks_2sampResult:
103-
return ks_2samp(reference_series, actual_series)
101+
from scipy.stats import KstestResult
104102

103+
def _calculate_ks(actual_series, reference_series) -> "KstestResult":
104+
return ks_2samp(reference_series, actual_series)
105105

106106
def _calculate_earth_movers_distance(actual_series, reference_series):
107107
unique_reference = np.unique(reference_series)
@@ -806,7 +806,7 @@ def test_drift_prediction_ks(
806806
else pd.Series(model.predict(actual_dataset).prediction)
807807
)
808808

809-
result: Ks_2sampResult = _calculate_ks(prediction_reference, prediction_actual)
809+
result: KstestResult = _calculate_ks(prediction_reference, prediction_actual)
810810

811811
passed = True if threshold is None else bool(result.pvalue >= threshold)
812812

0 commit comments

Comments
 (0)