|
9 | 9 |
|
10 | 10 | import numpy as np |
11 | 11 | import pandas as pd |
12 | | -from scipy.stats import chi2, ks_2samp |
13 | | -from scipy.stats.stats import Ks_2sampResult, wasserstein_distance |
| 12 | +from scipy.stats import chi2, ks_2samp, wasserstein_distance |
14 | 13 |
|
15 | 14 | from giskard.core.test_result import TestMessage, TestMessageLevel, TestResult |
16 | 15 | from giskard.datasets.base import Dataset |
@@ -99,9 +98,10 @@ def _calculate_drift_psi(actual_series, reference_series, max_categories): |
99 | 98 | return total_psi, pd.DataFrame(output_data) |
100 | 99 |
|
101 | 100 |
|
102 | | -def _calculate_ks(actual_series, reference_series) -> Ks_2sampResult: |
103 | | - return ks_2samp(reference_series, actual_series) |
| 101 | +from scipy.stats import KstestResult |
104 | 102 |
|
| 103 | +def _calculate_ks(actual_series, reference_series) -> "KstestResult": |
| 104 | + return ks_2samp(reference_series, actual_series) |
105 | 105 |
|
106 | 106 | def _calculate_earth_movers_distance(actual_series, reference_series): |
107 | 107 | unique_reference = np.unique(reference_series) |
@@ -806,7 +806,7 @@ def test_drift_prediction_ks( |
806 | 806 | else pd.Series(model.predict(actual_dataset).prediction) |
807 | 807 | ) |
808 | 808 |
|
809 | | - result: Ks_2sampResult = _calculate_ks(prediction_reference, prediction_actual) |
| 809 | + result: KstestResult = _calculate_ks(prediction_reference, prediction_actual) |
810 | 810 |
|
811 | 811 | passed = True if threshold is None else bool(result.pvalue >= threshold) |
812 | 812 |
|
|
0 commit comments