Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
implemented @Hartorn's feedback
  • Loading branch information
rabah-khalek committed Dec 14, 2023
commit 1821a0a703aec79ad7a8867607463771a863f84c
20 changes: 9 additions & 11 deletions examples/ex5_models_comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,13 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/rak/Documents/loreal-poc/loreal_poc/tests/performance.py:47: RuntimeWarning: Mean of empty slice\n",
" mes = np.nanmean(es, axis=1)\n",
"/Users/rak/Documents/loreal-poc/loreal_poc/tests/performance.py:85: RuntimeWarning: Mean of empty slice\n",
" return np.nanmean(NMEs.get(prediction_result, marks))\n"
]
Expand Down Expand Up @@ -191,7 +189,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -235,7 +233,7 @@
" <td>False</td>\n",
" <td>bottom half</td>\n",
" <td>OpenCV</td>\n",
" <td>Cropped on bottom half</td>\n",
" <td>300W cropped on bottom half</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
Expand All @@ -246,7 +244,7 @@
" <td>True</td>\n",
" <td>upper half</td>\n",
" <td>OpenCV</td>\n",
" <td>Cropped on upper half</td>\n",
" <td>300W cropped on upper half</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
Expand All @@ -257,12 +255,12 @@
"0 TestDiff NME_mean NaN 1 False bottom half \n",
"1 TestDiff NME_mean 0.040216 1 True upper half \n",
"\n",
" model_name dataloader_name \n",
"0 OpenCV Cropped on bottom half \n",
"1 OpenCV Cropped on upper half "
" model_name dataloader_name \n",
"0 OpenCV 300W cropped on bottom half \n",
"1 OpenCV 300W cropped on upper half "
]
},
"execution_count": 13,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -290,7 +288,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
7 changes: 5 additions & 2 deletions loreal_poc/dataloaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,12 @@ def _validate_image(cls, image: np.ndarray) -> None:


class DataLoaderWrapper(DataIteratorBase):
def __init__(self, dataloader: DataIteratorBase, name: Optional[str] = None) -> None:
def __init__(self, dataloader: DataIteratorBase) -> None:
self._wrapped_dataloader = dataloader
self.name = name

@property
def name(self):
return f"{self.__class__.__name__}({self._wrapped_dataloader.name})"

def __len__(self) -> int:
return len(self._wrapped_dataloader)
Expand Down
13 changes: 10 additions & 3 deletions loreal_poc/dataloaders/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@ def __init__(
crop_img: bool = True,
crop_marks: bool = True,
) -> None:
name = f"Cropped on {part.name}"
super().__init__(dataloader, name=name)
super().__init__(dataloader)
self._part = part
self._margins = margins
self.crop_img = crop_img
self.crop_marks = crop_marks

@property
def name(self):
return f"{self._wrapped_dataloader.name} cropped on {self._part.name}"

def get_image(self, idx: int) -> np.ndarray:
image = super().get_image(idx)
if not self.crop_img:
Expand All @@ -47,11 +50,15 @@ def __getitem__(self, idx: int) -> Tuple[np.ndarray, Optional[np.ndarray], Optio

class CachedDataLoader(DataLoaderWrapper):
def __init__(self, dataloader: DataIteratorBase, cache_size: int = 20) -> None:
super().__init__(dataloader, name="Cached")
super().__init__(dataloader)
self._max_size: int = cache_size
self._cache_idxs: List[int] = []
self._cache: Dict[int, Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]] = {}

@property
def name(self):
return f"Cached {self._wrapped_dataloader.name}"

def __getitem__(self, idx: int) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]:
# Add basic LRU cache to avoid reloading images and marks on small dataloaders
if idx in self._cache_idxs:
Expand Down