-
Notifications
You must be signed in to change notification settings - Fork 3
Add demo image classification models and datasets #41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…e-classification-hf
| """ | ||
|
|
||
| def __init__( | ||
| self, hf_id: str, hf_config: Optional[str] = None, hf_split: str = "train", name: Optional[str] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| self, hf_id: str, hf_config: Optional[str] = None, hf_split: str = "train", name: Optional[str] = None | |
| self, hf_id: str, hf_config: Optional[str] = None, hf_split: str = "test", name: Optional[str] = None |
wouldn't this make more sense? On the other hand, sometimes there're only "train" split and not "test", but still.
|
|
||
| return MetaData( | ||
| data=flat_meta, | ||
| categories=flat_meta.keys(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would do the opposite, no category features by default
| def __len__(self): | ||
| return len(self.ds) | ||
|
|
||
| def get_meta(self, idx: int) -> Optional[TypesBase.meta]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should put the get_meta on this level, it's not very useful, as we'd need to specify each issuegroup per meta, I would also not have meta_exclude_keys as attribute, it's complicated to understand what it means,
instead I would just implement get_meta for each daughter class, with a custom list of exclude, and custom list of categories
| def predict_image(self, image: np.ndarray) -> np.ndarray: | ||
| """method that takes one image as input and outputs the prediction of probabilities for each class | ||
|
|
||
| Args: | ||
| image (np.ndarray): input image | ||
| """ | ||
| _raw_prediction = self.pipeline( | ||
| image, | ||
| top_k=len(self.classification_labels), # Get probabilities for all labels | ||
| ) | ||
| _prediction = {p["label"]: p["score"] for p in _raw_prediction} | ||
|
|
||
| return np.array([_prediction[label] for label in self.classification_labels]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think since the daughter classes are predicting labels, it's better to call this method predict_probas, and leave the predict_image abstract for this class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
| ) | ||
|
|
||
| def predict_image(self, image) -> np.ndarray: | ||
| probas = super().predict_image(image) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would instead use self.predict_probas here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in the single label base class
| return np.array([np.argmax(probas)]) | ||
|
|
||
|
|
||
| class MicrosoftResNetImageNet50HuggingFaceModel(ImageClassificationHuggingFaceModel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why no predict_image here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplified to have a single label base class
| ) | ||
|
|
||
|
|
||
| class Jsli96ResNetImageNetHuggingFaceModel(ImageClassificationHuggingFaceModel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why no predict_image here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Simplified to have a single label base class
| TypesBase, | ||
| ) | ||
|
|
||
| CLASSIFICATION_LABEL_TYPE = np.ndarray # Probabilities for each class |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The classification type is not probabilities.
We need to implement a threshold for binary classification models to get the predicted classes
We already have in place the argmax for the multi-classifcation case
In both cases though, I would choose between:
int: label_idstring: label
I'm more in favour of label as it'll be more readable when we use model.predict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assumed an N-dimensional array to have both multiple labels for one prediction and multiple predictions.
Image classification metrics
…rd-vision into image-classification-hf
|
I aligned all datasets and models to use string of class labels now. There are 2 new notebooks to try all models and datasets on. |
rabah-khalek
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, good job
|
I would just add tests to the core stuff (no need to test the demo dataloaders or models) rather the general wrappers |
Enable Skin Cancer detection with single label.
It takes around 4 minutes to predict 1285 images in
testsplit ofmarmal88/skin_cancerdatasets: