Skip to content
Prev Previous commit
Next Next commit
fixed predict, reproducing main
  • Loading branch information
rabah-khalek committed Jan 3, 2024
commit c78a235e388505d056a29897e615f68e2d0de638
24 changes: 10 additions & 14 deletions loreal_poc/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass
from logging import getLogger
from time import time
from typing import List, Optional, Union
from typing import List, Optional

import numpy as np

Expand All @@ -19,8 +19,8 @@ class PredictionResult:
prediction_time: float = None


def is_failed(prediction):
return np.isnan(prediction).sum() == prediction.size
def calculate_fail_rate(prediction):
return (np.isnan(prediction).sum(axis=(1, 2)) / prediction[0].size).sum()


class FaceLandmarksModelBase(ABC):
Expand Down Expand Up @@ -65,7 +65,7 @@ def predict_batch(self, images: List[np.ndarray]) -> np.ndarray:
return res

def _postprocessing(
self, batch_prediction: Union[np.ndarray, List[Optional[np.ndarray]]], batch_size: int, facial_part: FacialPart
self, batch_prediction: List[Optional[np.ndarray]], batch_size: int, facial_part: FacialPart
) -> np.ndarray:
"""method that performs postprocessing on single batch prediction

Expand All @@ -76,18 +76,15 @@ def _postprocessing(
Returns:
np.ndarray: single batch image prediction filtered based on landmarks in facial_part
"""
if batch_prediction is None or (hasattr(batch_prediction, "shape") and not batch_prediction.shape):
if all(elt is None for elt in batch_prediction):
res = np.empty((batch_size, self.n_landmarks, self.n_dimensions))
res[:, :, :] = np.nan
elif not hasattr(batch_prediction, "shape") and all([elt is not None for elt in batch_prediction]):
elif all([elt is not None for elt in batch_prediction]):
res = np.array(batch_prediction)
elif not hasattr(batch_prediction, "shape"):
else:
res = np.empty((batch_size, self.n_landmarks, self.n_dimensions))
for i, elt in enumerate(batch_prediction):
if elt is not None:
res[i] = elt if elt is not None else np.nan
else:
res = batch_prediction
res[i] = elt if elt is not None else np.nan
if res.shape[1:] != (self.n_landmarks, self.n_dimensions):
raise ValueError(
f"{self.__class__.__name__}: The array shape expected from predict_batch is ({batch_size}, {self.n_landmarks}, {self.n_dimensions}) but {res.shape} was found."
Expand All @@ -114,10 +111,9 @@ def predict(self, dataloader: DataIteratorBase, facial_part: Optional[FacialPart
for images, _, _ in dataloader:
batch_prediction = self.predict_batch(images)
batch_prediction = self._postprocessing(batch_prediction, len(images), facial_part)
if is_failed(batch_prediction):
prediction_fail_rate += 1
prediction_fail_rate += calculate_fail_rate(batch_prediction)
predictions.append(batch_prediction)
prediction_fail_rate /= len(dataloader)
prediction_fail_rate /= dataloader.flat_len()
te = time()
predictions = np.concatenate(predictions)
if len(predictions.shape) > 3:
Expand Down