Merged
59 commits
9800018
GSK-234 added unit tests from GSK 106
princyiakov Aug 2, 2022
eec9acf
GSK-234 save_df and compress removed from utils.py
princyiakov Aug 2, 2022
fec4218
GSK-234 drift tests improved
princyiakov Aug 2, 2022
d611553
GSK-234 data_drift.yml updated from GSK 106
princyiakov Aug 2, 2022
b423631
GSK-234 prediction_drift.yml updated from GSK 106
princyiakov Aug 2, 2022
45b19d1
Merge branch 'main' into GSK-234_improve_current_tests
princyiakov Aug 2, 2022
1c820eb
GSK-232 added column property to giskard dataset
princyiakov Aug 2, 2022
806441e
GSK-232 improved heuristics tests
princyiakov Aug 2, 2022
7e19ea5
GSK-232 updated heuristic.yml
princyiakov Aug 2, 2022
68f6d67
GSK-232 removed output_df from docstrings
princyiakov Aug 2, 2022
3f26d9b
GSK-232 improved metamorphic tests
princyiakov Aug 2, 2022
a55c3df
GSK-232 updated metamorphic tests
princyiakov Aug 2, 2022
01da401
GSK-232 updated performance tests
princyiakov Aug 2, 2022
f544896
GSK-234 added column property to giskard dataset
princyiakov Aug 2, 2022
ae1f386
GSK-234 improved heuristics tests
princyiakov Aug 2, 2022
1660dd8
GSK-234 updated heuristic.yml
princyiakov Aug 2, 2022
102d056
GSK-234 removed output_df from docstrings
princyiakov Aug 2, 2022
2eba330
GSK-234 improved metamorphic tests
princyiakov Aug 2, 2022
0faef10
GSK-234 updated metamorphic tests
princyiakov Aug 2, 2022
8aadb57
GSK-234 updated performance tests
princyiakov Aug 2, 2022
a9af91b
GSK-234 heuristic docstring update
princyiakov Aug 2, 2022
37e2fbf
GSK-234 output_df removed from doctring of prediction_drift.yml
princyiakov Aug 2, 2022
df3f1aa
Merge remote-tracking branch 'origin/GSK-234_improve_current_tests' i…
princyiakov Aug 2, 2022
94e4689
GSK-234 removed unused parameters
princyiakov Aug 2, 2022
c8dd3f4
GSK-234 docstring update
princyiakov Aug 2, 2022
3e2fae5
GSK-234 restore function name to _calculate_psi
princyiakov Aug 2, 2022
d2abf45
GSK-234 removed unused parameters from drift tests
princyiakov Aug 2, 2022
6c654ad
GSK-234 input parameters types updated
princyiakov Aug 2, 2022
b6c82d6
GSK-234 removed unused _calculate_numerical_drift function
princyiakov Aug 2, 2022
5617136
GSK-234 removed output_df from docstring
princyiakov Aug 2, 2022
22ccb2d
GSK-234 removed output_df from docstring
princyiakov Aug 2, 2022
7888e25
GSK-234 reorder private and public functions
princyiakov Aug 2, 2022
b1af3ea
GSK-234 docstring update
princyiakov Aug 2, 2022
71f2fb2
GSK-234 threshold set to 0.5 and output_df removed from docstring
princyiakov Aug 3, 2022
b8e86ff
GSK-234 improved perturbations
princyiakov Aug 3, 2022
57fdb30
GSK-234 bug fix for test diff prediction
princyiakov Aug 3, 2022
a00ea16
GSK-234 zerodivisionerror fix
princyiakov Aug 3, 2022
2d9c03a
GSK-234 zerodivisionerror raising error instead of warning
princyiakov Aug 3, 2022
61bb9af
GSK-234 docstring update
princyiakov Aug 4, 2022
01b527d
GSK-234 nlp augmenter added
princyiakov Aug 4, 2022
95cd1ba
GSK-234 added optional argument in the end, removed message
princyiakov Aug 4, 2022
98cc867
GSK-234 metric update in tests
princyiakov Aug 4, 2022
e3bf0db
Merge branch 'main' into GSK-234_improve_current_tests
princyiakov Aug 22, 2022
65537d1
small cleanup
andreybavt Aug 24, 2022
5fe5bce
cleanup
andreybavt Aug 24, 2022
4741b12
Merge branch 'main' into GSK-234_improve_current_tests
princyiakov Aug 24, 2022
1914d70
GSK-234 rename other_modalities
princyiakov Aug 24, 2022
dc53ba5
GSK-234 improved error message for _validate_column_type
princyiakov Aug 24, 2022
bca4545
GSK-234 refactored signature
princyiakov Aug 24, 2022
51b62bf
GSK-234 improved message for data drifts tests
princyiakov Aug 24, 2022
d59d8ae
Merge remote-tracking branch 'origin/GSK-234_improve_current_tests' i…
princyiakov Aug 24, 2022
ee6734c
GSK-234 refactored duplicated codes
princyiakov Aug 25, 2022
a65176a
GSK-234 refactored duplicated codes
princyiakov Aug 25, 2022
1bb0d2f
GSK-234 function name update
princyiakov Aug 25, 2022
9567adb
GSK-234 sending giskard dataset instead of dataframe in yml file
princyiakov Aug 26, 2022
de82671
GSK-234 renamed function
princyiakov Aug 26, 2022
d6f1d06
small message improvement
andreybavt Aug 30, 2022
0ce74e3
Merge branch 'main' into GSK-234_improve_current_tests
andreybavt Aug 30, 2022
15695c2
added test messages to the UI
andreybavt Aug 30, 2022
GSK-232 improved metamorphic tests
princyiakov committed Aug 2, 2022
commit 3f26d9b7e76dae03f342e279f6fbad7213b07e78
89 changes: 42 additions & 47 deletions giskard-ml-worker/ml_worker/testing/metamorphic_tests.py
@@ -1,11 +1,11 @@
import pandas as pd

from generated.ml_worker_pb2 import SingleTestResult
from generated.ml_worker_pb2 import SingleTestResult, TestMessage, TestMessageType

from ml_worker.core.giskard_dataset import GiskardDataset
from ml_worker.core.model import GiskardModel
from ml_worker.testing.abstract_test_collection import AbstractTestCollection
from ml_worker.testing.utils import apply_perturbation_inplace
from ml_worker.testing.utils import save_df, compress


class MetamorphicTests(AbstractTestCollection):
@@ -72,6 +72,8 @@ def _test_metamorphic(self,
output_sensitivity=None,
output_proba=True
) -> SingleTestResult:
actual_slice.df.reset_index(drop=True, inplace=True)

results_df, modified_rows_count = self._perturb_and_predict(actual_slice,
model,
perturbation_dict,
@@ -82,17 +84,18 @@ def _test_metamorphic(self,
model.model_type,
output_sensitivity,
flag)
failed_df = actual_slice.df.loc[failed_idx]
passed_ratio = len(passed_idx) / modified_rows_count if modified_rows_count != 0 else 1

output_df_sample = compress(save_df(failed_df))
messages = [TestMessage(
type=TestMessageType.INFO,
text=f"{modified_rows_count} number of rows were perturbed"
)]

return self.save_results(SingleTestResult(
actual_slices_size=[len(actual_slice)],
number_of_perturbed_rows=modified_rows_count,
metric=passed_ratio,
passed=passed_ratio > threshold,
output_df=output_df_sample))
messages=messages))

def test_metamorphic_invariance(self,
df: GiskardDataset,
@@ -116,30 +119,28 @@ def test_metamorphic_invariance(self,

Args:
df(GiskardDataset):
Dataset used to compute the test
Dataset used to compute the test
model(GiskardModel):
Model used to compute the test
Model used to compute the test
perturbation_dict(dict):
Dictionary of the perturbations. It provides the perturbed features as key and a perturbation lambda function as value
Dictionary of the perturbations. It provides the perturbed features as key
and a perturbation lambda function as value
threshold(float):
Threshold of the ratio of invariant rows
Threshold of the ratio of invariant rows
output_sensitivity(float):
the threshold for ratio between the difference between perturbed prediction and actual prediction over
Optional. The threshold for ratio between the difference between perturbed prediction and actual prediction over
the actual prediction for a regression model. We consider there is a prediction difference for
regression if the ratio is above the output_sensitivity of 0.1

Returns:
actual_slices_size:
total number of rows of dataframe
number_of_perturbed_rows:
number of perturbed rows
Length of actual_slice tested
message:
Test result message
metric:
the ratio of invariant rows over the perturbed rows
The ratio of unchanged rows over the perturbed rows
passed:
TRUE if passed_ratio > threshold
output_df:
dataframe containing the non-invariant rows

TRUE if metric > threshold
"""

return self._test_metamorphic(flag='Invariance',
@@ -173,29 +174,26 @@ def test_metamorphic_increasing(self,

Args:
df(GiskardDataset):
Dataset used to compute the test
Dataset used to compute the test
model(GiskardModel):
Model used to compute the test
Model used to compute the test
perturbation_dict(dict):
Dictionary of the perturbations. It provides the perturbed features as key
and a perturbation lambda function as value
Dictionary of the perturbations. It provides the perturbed features as key
and a perturbation lambda function as value
threshold(float):
Threshold of the ratio of increasing rows
Threshold of the ratio of increasing rows
classification_label(str):
one specific label value from the target column
Optional. One specific label value from the target column

Returns:
actual_slices_size:
total number of rows of dataframe
number_of_perturbed_rows:
number of perturbed rows
Length of actual_slice tested
message:
Test result message
metric:
the ratio of increasing rows over the perturbed rows
The ratio of increasing rows over the perturbed rows
passed:
TRUE if passed_ratio > threshold
output_df:
dataframe containing the rows whose probability doesn't increase after perturbation

TRUE if metric > threshold
"""
assert model.model_type != "classification" or str(classification_label) in model.classification_labels, \
f'"{classification_label}" is not part of model labels: {",".join(model.classification_labels)}'
@@ -230,29 +228,26 @@ def test_metamorphic_decreasing(self,

Args:
df(GiskardDataset):
Dataset used to compute the test
Dataset used to compute the test
model(GiskardModel):
Model used to compute the test
Model used to compute the test
perturbation_dict(dict):
Dictionary of the perturbations. It provides the perturbed features as key
and a perturbation lambda function as value
Dictionary of the perturbations. It provides the perturbed features as key
and a perturbation lambda function as value
threshold(float):
Threshold of the ratio of decreasing rows
Threshold of the ratio of decreasing rows
classification_label(str):
one specific label value from the target column
Optional. One specific label value from the target column

Returns:
actual_slices_size:
total number of rows of dataframe
number_of_perturbed_rows:
number of perturbed rows
Length of actual_slice tested
message:
Test result message
metric:
the ratio of decreasing rows over the perturbed rows
The ratio of decreasing rows over the perturbed rows
passed:
TRUE if passed_ratio > threshold
output_df:
dataframe containing the rows whose probability doesn't decrease after perturbation

TRUE if metric > threshold
"""

assert model.model_type != "classification" or classification_label in model.classification_labels, \
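
For readers following the diff, here is a minimal, self-contained sketch of the invariance check that the docstrings above describe. It is not the ml_worker implementation. It mirrors only what the diff shows: the metric is the ratio of passing rows over perturbed rows, it defaults to 1 when no row was perturbed, and the test passes if the metric is strictly above the threshold. The function name `metamorphic_invariance`, the plain-dict row format, the convention that each perturbation lambda receives the whole row, the default values, and the handling of a zero prediction are all assumptions made for illustration.

```python
from typing import Callable, Dict, List, Tuple

Row = Dict[str, float]


def metamorphic_invariance(
    rows: List[Row],
    predict: Callable[[Row], float],
    perturbation_dict: Dict[str, Callable[[Row], float]],
    threshold: float = 0.5,           # assumed default, for illustration only
    output_sensitivity: float = 0.1,  # value quoted in the docstring above
) -> Tuple[float, bool, int]:
    """Return (metric, passed, modified_rows_count) for a regression-style model."""
    passed_count = 0
    modified_rows_count = 0
    for row in rows:
        # Build the perturbed copy: feature name -> function of the original row.
        perturbed = dict(row)
        for feature, perturb in perturbation_dict.items():
            perturbed[feature] = perturb(row)
        if perturbed == row:
            continue  # rows left unchanged by the perturbation are not counted
        modified_rows_count += 1

        before, after = predict(row), predict(perturbed)
        if before == 0:
            # Assumption: with a zero baseline, only an exactly equal
            # prediction counts as invariant (avoids dividing by zero).
            unchanged = after == 0
        else:
            unchanged = abs(after - before) / abs(before) <= output_sensitivity
        passed_count += unchanged

    # Same guard as in the diff: no perturbed rows means a ratio of 1.
    metric = passed_count / modified_rows_count if modified_rows_count != 0 else 1
    return metric, metric > threshold, modified_rows_count


# Hypothetical usage: a toy model and a perturbation that doubles "y".
rows = [{"x": 100.0, "y": 1.0}, {"x": 100.0, "y": 50.0}, {"x": 100.0, "y": 0.0}]
toy_predict = lambda r: r["x"] + r["y"]
metric, passed, n_perturbed = metamorphic_invariance(
    rows, toy_predict, {"y": lambda r: r["y"] * 2}
)
```

In this toy run the third row is unaffected by the perturbation, so only two rows are counted. One of them stays within the sensitivity, so the metric equals the threshold and the strict comparison makes the test fail.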