-
Notifications
You must be signed in to change notification settings - Fork 3
working on perturbation detectors #51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
05a4cc2
cf2a8c4
7ba35b6
6c3ffc9
7e14795
f3d8e32
c30dade
98baa6d
a9fa22f
e9198ce
4dd46b4
e547d4d
a44399d
fe26272
c359c9c
6dca401
99d98dd
6601ecb
14de1fa
182731c
6ba1994
ffbb425
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| from giskard_vision.image_classification.tests.performance import Accuracy | ||
| from giskard_vision.landmark_detection.tests.performance import NMEMean | ||
| from giskard_vision.object_detection.tests.performance import IoU | ||
|
|
||
| detector_metrics = { | ||
| "image_classification": Accuracy, | ||
| "landmark": NMEMean, | ||
| "object_detection": IoU, | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| import os | ||
| from abc import abstractmethod | ||
| from pathlib import Path | ||
| from typing import Any, Sequence | ||
|
|
||
| import cv2 | ||
|
|
||
| from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader | ||
| from giskard_vision.core.detectors.base import ( | ||
| DetectorVisionBase, | ||
| IssueGroup, | ||
| ScanResult, | ||
| ) | ||
| from giskard_vision.landmark_detection.tests.base import TestDiff | ||
| from giskard_vision.utils.errors import GiskardImportError | ||
|
|
||
| from .metrics import detector_metrics | ||
|
|
||
| Cropping = IssueGroup( | ||
| "Cropping", description="Cropping involves evaluating the landmark detection model on specific face areas." | ||
| ) | ||
|
|
||
| Ethical = IssueGroup( | ||
| "Ethical", | ||
| description="The data are filtered by ethnicity to detect ethical biases in the landmark detection model.", | ||
| ) | ||
|
|
||
| Pose = IssueGroup( | ||
| "Head Pose", | ||
| description="The data are filtered by head pose to detect biases in the landmark detection model.", | ||
| ) | ||
|
|
||
| Robustness = IssueGroup( | ||
| "Robustness", | ||
| description="Images from the dataset are blurred, recolored and resized to test the robustness of the model to transformations.", | ||
| ) | ||
rabah-khalek marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class PerturbationBaseDetector(DetectorVisionBase): | ||
| """ | ||
| Abstract class for Landmark Detection Detectors | ||
|
|
||
| Methods: | ||
| get_dataloaders(dataset: Any) -> Sequence[Any]: | ||
| Abstract method that returns a list of dataloaders corresponding to | ||
| slices or transformations | ||
|
|
||
| get_results(model: Any, dataset: Any) -> Sequence[ScanResult]: | ||
| Returns a list of ScanResult containing the evaluation results | ||
|
|
||
| get_scan_result(self, test_result) -> ScanResult: | ||
| Convert TestResult to ScanResult | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ... | ||
|
|
||
| def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]: | ||
| dataloaders = self.get_dataloaders(dataset) | ||
|
|
||
| results = [] | ||
| for dl in dataloaders: | ||
| test_result = TestDiff(metric=detector_metrics[model.model_type], threshold=1).run( | ||
| model=model, | ||
| dataloader=dl, | ||
| dataloader_ref=dataset, | ||
| ) | ||
|
|
||
| # Save example images from dataloader and dataset | ||
| current_path = str(Path()) | ||
| os.makedirs(f"{current_path}/examples_images", exist_ok=True) | ||
| filename_examples = [] | ||
|
|
||
| index_worst = 0 if test_result.indexes_examples is None else test_result.indexes_examples[0] | ||
|
|
||
| if isinstance(dl, FilteredDataLoader): | ||
| filename_example_dataloader_ref = str(Path() / "examples_images" / f"{dataset.name}_{index_worst}.png") | ||
| cv2.imwrite( | ||
| filename_example_dataloader_ref, cv2.resize(dataset[index_worst][0][0], (0, 0), fx=0.3, fy=0.3) | ||
| ) | ||
rabah-khalek marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| filename_examples.append(filename_example_dataloader_ref) | ||
|
|
||
| filename_example_dataloader = str(Path() / "examples_images" / f"{dl.name}_{index_worst}.png") | ||
| cv2.imwrite(filename_example_dataloader, cv2.resize(dl[index_worst][0][0], (0, 0), fx=0.3, fy=0.3)) | ||
| filename_examples.append(filename_example_dataloader) | ||
| results.append(self.get_scan_result(test_result, filename_examples, dl.name, len(dl))) | ||
|
|
||
| return results | ||
|
|
||
| def get_scan_result(self, test_result, filename_examples, name, size_data) -> ScanResult: | ||
| try: | ||
| from giskard.scanner.issues import IssueLevel | ||
| except (ImportError, ModuleNotFoundError) as e: | ||
| raise GiskardImportError(["giskard"]) from e | ||
|
|
||
| relative_delta = (test_result.metric_value_test - test_result.metric_value_ref) / test_result.metric_value_ref | ||
|
||
|
|
||
| if relative_delta > self.issue_level_threshold + self.deviation_threshold: | ||
| issue_level = IssueLevel.MAJOR | ||
| elif relative_delta > self.issue_level_threshold: | ||
| issue_level = IssueLevel.MEDIUM | ||
| else: | ||
| issue_level = IssueLevel.MINOR | ||
|
|
||
| return ScanResult( | ||
| name=name, | ||
| metric_name=test_result.metric_name, | ||
| metric_value=test_result.metric_value_test, | ||
| metric_reference_value=test_result.metric_value_ref, | ||
| issue_level=issue_level, | ||
| slice_size=size_data, | ||
| filename_examples=filename_examples, | ||
| relative_delta=relative_delta, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| from giskard_vision.core.dataloaders.wrappers import BlurredDataLoader | ||
|
|
||
| from ...core.detectors.decorator import maybe_detector | ||
| from .perturbation import PerturbationBaseDetector, Robustness | ||
|
|
||
|
|
||
| @maybe_detector("blurring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"]) | ||
| class TransformationBlurringDetectorLandmark(PerturbationBaseDetector): | ||
| """ | ||
| Detector that evaluates models performance on blurred images | ||
| """ | ||
|
|
||
| issue_group = Robustness | ||
|
|
||
| def __init__(self, kernel_size=(11, 11), sigma=(3, 3)): | ||
| self.kernel_size = kernel_size | ||
| self.sigma = sigma | ||
|
|
||
| def get_dataloaders(self, dataset): | ||
| dl = BlurredDataLoader(dataset, self.kernel_size, self.sigma) | ||
|
|
||
| dls = [dl] | ||
|
|
||
| return dls |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| from giskard_vision.core.dataloaders.wrappers import ColoredDataLoader | ||
|
|
||
| from ...core.detectors.decorator import maybe_detector | ||
| from .perturbation import PerturbationBaseDetector, Robustness | ||
|
|
||
|
|
||
| @maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"]) | ||
| class TransformationColorDetectorLandmark(PerturbationBaseDetector): | ||
| """ | ||
| Detector that evaluates models performance depending on images in grayscale | ||
| """ | ||
|
|
||
| issue_group = Robustness | ||
|
|
||
| def get_dataloaders(self, dataset): | ||
| dl = ColoredDataLoader(dataset) | ||
|
|
||
| dls = [dl] | ||
|
|
||
| return dls |
Uh oh!
There was an error while loading. Please reload this page.