Skip to content
Prev Previous commit
Next Next commit
GSK-177 multilabel and target value unavailable scenario unit test added
  • Loading branch information
princyiakov committed Jul 7, 2022
commit f99b3b69fd2d42b2948cd4609918b54d80c4571a
47 changes: 45 additions & 2 deletions giskard-ml-worker/test/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,24 @@ def test_auc(german_credit_data, german_credit_model):
assert not _test_auc(german_credit_data, german_credit_model, 0.8)


def _test_auc_multilabel(enron_data, enron_model, threshold):
    """Run the AUC performance test on the Enron fixtures and report pass/fail.

    Besides returning the verdict, this helper pins two properties of the
    result: the first reported actual-slice size (50) and the metric rounded
    to two decimals (0.95).

    Args:
        enron_data: Fixture handed to the test as ``actual_slice``.
        enron_model: Fixture handed to the test as ``model``.
        threshold: Threshold forwarded to ``test_auc``.

    Returns:
        The ``passed`` attribute of the test result.
    """
    auc_result = GiskardTestFunctions().performance.test_auc(
        actual_slice=enron_data,
        model=enron_model,
        threshold=threshold,
    )

    slice_size = auc_result.actual_slices_size[0]
    assert slice_size == 50
    rounded_auc = round(auc_result.metric, 2)
    assert rounded_auc == 0.95
    return auc_result.passed


def test_auc_multilabel(enron_data, enron_model):
    """AUC check on the Enron fixtures passes at threshold 0.5 and fails at 1."""
    passed_at_half = _test_auc_multilabel(enron_data, enron_model, 0.5)
    assert passed_at_half

    passed_at_one = _test_auc_multilabel(enron_data, enron_model, 1)
    assert not passed_at_one


def _test_f1(german_credit_data, german_credit_model, threshold):
tests = GiskardTestFunctions()
results = tests.performance.test_f1(
Expand All @@ -41,6 +59,26 @@ def test_f1(german_credit_data, german_credit_model):
assert not _test_f1(german_credit_data, german_credit_model, 0.9)


def _test_f1_multilabel(enron_data, enron_model, threshold):
    """Run the F1 performance test on the Enron fixtures and report pass/fail.

    Besides returning the verdict, this helper pins three properties of the
    result: the first reported actual-slice size (50), the metric rounded to
    two decimals (0.68), and that ``output_df`` is a ``bytes`` payload.

    Args:
        enron_data: Fixture handed to the test as ``actual_slice``.
        enron_model: Fixture handed to the test as ``model``.
        threshold: Threshold forwarded to ``test_f1``.

    Returns:
        The ``passed`` attribute of the test result.
    """
    tests = GiskardTestFunctions()
    results = tests.performance.test_f1(
        actual_slice=enron_data,
        model=enron_model,
        threshold=threshold,
    )

    assert results.actual_slices_size[0] == 50

    assert round(results.metric, 2) == 0.68
    # isinstance() rather than an exact type() comparison, so a bytes
    # subclass is still accepted.
    assert isinstance(results.output_df, bytes)
    return results.passed


def test_f1_multilabel(enron_data, enron_model):
    """F1 check on the Enron fixtures passes at threshold 0.2 and fails at 0.9."""
    passed_at_low = _test_f1_multilabel(enron_data, enron_model, 0.2)
    assert passed_at_low

    passed_at_high = _test_f1_multilabel(enron_data, enron_model, 0.9)
    assert not passed_at_high


def _test_precision(german_credit_data, german_credit_model, threshold):
tests = GiskardTestFunctions()
results = tests.performance.test_precision(
Expand Down Expand Up @@ -115,7 +153,7 @@ def _test_rmse(diabetes_dataset_with_target, linear_regression_diabetes, thresho

def test_rmse(diabetes_dataset_with_target, linear_regression_diabetes):
    """RMSE helper reports failure at threshold 52 and success at 54."""
    assert not _test_rmse(diabetes_dataset_with_target, linear_regression_diabetes, 52)
    assert _test_rmse(diabetes_dataset_with_target, linear_regression_diabetes, 54)


def _test_mae(diabetes_dataset_with_target, linear_regression_diabetes, threshold=44):
Expand All @@ -134,7 +172,7 @@ def _test_mae(diabetes_dataset_with_target, linear_regression_diabetes, threshol

def test_mae(diabetes_dataset_with_target, linear_regression_diabetes):
    """MAE helper reports failure at threshold 43 and success at 44."""
    assert not _test_mae(diabetes_dataset_with_target, linear_regression_diabetes, 43)
    assert _test_mae(diabetes_dataset_with_target, linear_regression_diabetes, 44)


def _test_r2(diabetes_dataset_with_target, linear_regression_diabetes, threshold):
Expand Down Expand Up @@ -305,3 +343,8 @@ def _test_diff_reference_actual_rmse(diabetes_dataset_with_target, linear_regres
def test_diff_reference_actual_rmse(diabetes_dataset_with_target, linear_regression_diabetes):
    """Reference-vs-actual RMSE helper succeeds at threshold 0.4 and fails at 0.01."""
    passed_at_loose = _test_diff_reference_actual_rmse(
        diabetes_dataset_with_target, linear_regression_diabetes, 0.4
    )
    assert passed_at_loose

    passed_at_tight = _test_diff_reference_actual_rmse(
        diabetes_dataset_with_target, linear_regression_diabetes, 0.01
    )
    assert not passed_at_tight


def test_recall_exception(enron_test_data, enron_model):
    """The recall helper must raise when run against ``enron_test_data``."""
    # NOTE(review): presumably this fixture lacks target values, which is why
    # the recall test cannot be computed -- confirm against the fixture and
    # consider narrowing ``Exception`` to the concrete error type.
    recall_threshold = 0.4
    with pytest.raises(Exception):
        _test_recall(enron_test_data, enron_model, recall_threshold)