Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
05a4cc2
working on perturbation detectors
rabah-khalek Aug 8, 2024
cf2a8c4
Refactor HF ppl model to convert numpy array to PIL image
Inokinoki Aug 5, 2024
7ba35b6
Allow to set global mode for an HF ppl model for PIL conversion
Inokinoki Aug 5, 2024
6c3ffc9
mode switch in hf models
rabah-khalek Aug 8, 2024
7e14795
supporting gray scale
rabah-khalek Aug 8, 2024
f3d8e32
Merge branch 'main' into perturbation-detectors
rabah-khalek Aug 8, 2024
c30dade
Merge branch 'main' into perturbation-detectors
rabah-khalek Aug 10, 2024
98baa6d
added missing predict_rgb_image
rabah-khalek Aug 12, 2024
a9fa22f
ensuring backward compatibility with predict_image
rabah-khalek Aug 12, 2024
e9198ce
updating detectors
rabah-khalek Aug 12, 2024
4dd46b4
Adding noise perturbation detector with Gaussian noise (#52)
bmalezieux Aug 12, 2024
e547d4d
updating detectors
rabah-khalek Aug 12, 2024
a44399d
refactoring detectors
rabah-khalek Aug 13, 2024
fe26272
small updates
rabah-khalek Aug 13, 2024
c359c9c
refactored spec setting
rabah-khalek Aug 13, 2024
6dca401
fixed import in object_detection dataloader
rabah-khalek Aug 13, 2024
99d98dd
renaming pert detectors
rabah-khalek Aug 13, 2024
6601ecb
Merge branch 'main' into perturbation-detectors
rabah-khalek Aug 13, 2024
14de1fa
Merge branch 'perturbation-detectors' into refactoring-detectors
rabah-khalek Aug 13, 2024
182731c
fixed import
rabah-khalek Aug 13, 2024
6ba1994
fixing get_scan_results args
rabah-khalek Aug 13, 2024
ffbb425
Merge pull request #53 from Giskard-AI/refactoring-detectors
rabah-khalek Aug 13, 2024
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Allow to set global mode for an HF ppl model for PIL conversion
  • Loading branch information
Inokinoki authored and rabah-khalek committed Aug 8, 2024
commit 7ba35b6ef6e04a5d7e2352fc18b863b2094cf4db
6 changes: 5 additions & 1 deletion giskard_vision/core/dataloaders/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ class HFDataLoader(DataIteratorBase):
"""

def __init__(
self, hf_id: str, hf_config: Optional[str] = None, hf_split: str = "test", name: Optional[str] = None
self,
hf_id: str,
hf_config: Optional[str] = None,
hf_split: str = "test",
name: Optional[str] = None,
) -> None:
"""
Initializes the general HuggingFace Datasets instance.
Expand Down
6 changes: 3 additions & 3 deletions giskard_vision/core/models/hf_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def __init__(
"""init method that accepts a model object, number of landmarks and dimensions

Args:
model_id (str): Hugging Face model ID
name (Optional[str]): name of the model
pipeline_task (HFPipelineTask): HuggingFace pipeline task
model_id (str): Hugging Face model ID.
name (Optional[str]): name of the model.
pipeline_task (HFPipelineTask): HuggingFace pipeline task.

Raises:
GiskardImportError: If there are missing Hugging Face dependencies.
Expand Down
24 changes: 14 additions & 10 deletions giskard_vision/image_classification/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ class ImageClassificationHFModel(HFPipelineModelBase):
"""Hugging Face pipeline wrapper class that serves as a template for image classification predictions

Args:
model_id (str): Hugging Face model ID
name (Optional[str]): name of the model
device (str): device to run the model on
model_id (str): Hugging Face model ID.
name (Optional[str]): name of the model.
device (str): device to run the model on.
mode (str): The mode to convert the numpy image data to PIL image, defaulting to "RGB".

Attributes:
classification_labels: list of classification labels, where the position of the label corresponds to the class index
Expand All @@ -22,13 +23,14 @@ class ImageClassificationHFModel(HFPipelineModelBase):
model_type = "image_classification"
prediction_result_cls = Types.prediction_result

def __init__(self, model_id: str, name: Optional[str] = None, device: str = "cpu"):
def __init__(self, model_id: str, name: Optional[str] = None, device: str = "cpu", mode: str = "RGB"):
"""init method that accepts a model id, name and device

Args:
model_id (str): Hugging Face model ID
name (Optional[str]): name of the model
device (str): device to run the model on
model_id (str): Hugging Face model ID.
name (Optional[str]): name of the model.
device (str): device to run the model on.
mode (str): The mode to convert the numpy image data to PIL image, defaulting to "RGB".
"""

super().__init__(
Expand All @@ -39,6 +41,7 @@ def __init__(self, model_id: str, name: Optional[str] = None, device: str = "cpu
)

self._classification_labels = list(self.pipeline.model.config.id2label.values())
self._mode = mode

@property
def classification_labels(self):
Expand All @@ -58,14 +61,15 @@ class SingleLabelImageClassificationHFModelWrapper(ImageClassificationHFModel):
classification_labels: list of classification labels, where the position of the label corresponds to the class index
"""

def predict_probas(self, image: np.ndarray, mode="RGB") -> np.ndarray:
def predict_probas(self, image: np.ndarray, mode=None) -> np.ndarray:
"""method that takes one image as input and outputs the prediction of probabilities for each class

Args:
image (np.ndarray): input image
mode (str): mode of the image
"""
pil_image = Image.fromarray(image, mode=mode)
m = mode or self._mode
pil_image = Image.fromarray(image, mode=m)

# Pipeline takes a PIL image as input
_raw_prediction = self.pipeline(
Expand All @@ -76,7 +80,7 @@ def predict_probas(self, image: np.ndarray, mode="RGB") -> np.ndarray:

return np.array([_prediction[label] for label in self.classification_labels])

def predict_image(self, image, mode="RGB") -> Types.label:
def predict_image(self, image, mode=None) -> Types.label:
"""method that takes one image as input and outputs one class label

Args:
Expand Down