Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
05a4cc2
working on perturbation detectors
rabah-khalek Aug 8, 2024
cf2a8c4
Refactor HF ppl model to convert numpy array to PIL image
Inokinoki Aug 5, 2024
7ba35b6
Allow to set global mode for an HF ppl model for PIL conversion
Inokinoki Aug 5, 2024
6c3ffc9
mode switch in hf models
rabah-khalek Aug 8, 2024
7e14795
supporting gray scale
rabah-khalek Aug 8, 2024
f3d8e32
Merge branch 'main' into perturbation-detectors
rabah-khalek Aug 8, 2024
c30dade
Merge branch 'main' into perturbation-detectors
rabah-khalek Aug 10, 2024
98baa6d
added missing predict_rgb_image
rabah-khalek Aug 12, 2024
a9fa22f
ensuring backward compatibility with predict_image
rabah-khalek Aug 12, 2024
e9198ce
updating detectors
rabah-khalek Aug 12, 2024
4dd46b4
Adding noise perturbation detector with Gaussian noise (#52)
bmalezieux Aug 12, 2024
e547d4d
updating detectors
rabah-khalek Aug 12, 2024
a44399d
refactoring detectors
rabah-khalek Aug 13, 2024
fe26272
small updates
rabah-khalek Aug 13, 2024
c359c9c
refactored spec setting
rabah-khalek Aug 13, 2024
6dca401
fixed import in object_detection dataloader
rabah-khalek Aug 13, 2024
99d98dd
renaming pert detectors
rabah-khalek Aug 13, 2024
6601ecb
Merge branch 'main' into perturbation-detectors
rabah-khalek Aug 13, 2024
14de1fa
Merge branch 'perturbation-detectors' into refactoring-detectors
rabah-khalek Aug 13, 2024
182731c
fixed import
rabah-khalek Aug 13, 2024
6ba1994
fixing get_scan_results args
rabah-khalek Aug 13, 2024
ffbb425
Merge pull request #53 from Giskard-AI/refactoring-detectors
rabah-khalek Aug 13, 2024
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor HF ppl model to convert numpy array to PIL image
  • Loading branch information
Inokinoki authored and rabah-khalek committed Aug 8, 2024
commit cf2a8c456bc6f145ebd6ae42e7c18d82fe0dcd89
14 changes: 10 additions & 4 deletions giskard_vision/image_classification/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

import numpy as np
from PIL import Image

from giskard_vision.core.models.hf_pipeline import HFPipelineModelBase, HFPipelineTask
from giskard_vision.image_classification.types import Types
Expand Down Expand Up @@ -57,25 +58,30 @@ class SingleLabelImageClassificationHFModelWrapper(ImageClassificationHFModel):
classification_labels: list of classification labels, where the position of the label corresponds to the class index
"""

def predict_probas(self, image: np.ndarray) -> np.ndarray:
def predict_probas(self, image: np.ndarray, mode="RGB") -> np.ndarray:
"""method that takes one image as input and outputs the prediction of probabilities for each class

Args:
image (np.ndarray): input image
mode (str): mode of the image
"""
pil_image = Image.fromarray(image, mode=mode)

# Pipeline takes a PIL image as input
_raw_prediction = self.pipeline(
image,
pil_image,
top_k=len(self.classification_labels), # Get probabilities for all labels
)
_prediction = {p["label"]: p["score"] for p in _raw_prediction}

return np.array([_prediction[label] for label in self.classification_labels])

def predict_image(self, image) -> Types.label:
def predict_image(self, image, mode="RGB") -> Types.label:
"""method that takes one image as input and outputs one class label

Args:
image (np.ndarray): input image
mode (str): mode of the image
"""
probas = self.predict_probas(image)
probas = self.predict_probas(image, mode=mode)
return self.classification_labels[np.argmax(probas)]