55
66import numpy as np
77
8+ SingleLandmarkData = Tuple [np .ndarray , np .ndarray , Optional [Dict [Any , Any ]]]
9+ BatchedLandmarkData = Tuple [Tuple [np .ndarray ], np .ndarray , Tuple [Optional [Dict [Any , Any ]]]]
10+
811
912class DataIteratorBase (ABC ):
1013 batch_size : int
@@ -61,12 +64,10 @@ def get_meta_with_default(self, idx: int) -> np.ndarray:
6164 marks = marks if marks is not None else self .meta_none ()
6265 return marks
6366
64- def get_single_element (self , idx ) -> Tuple [ np . ndarray , Optional [ np . ndarray ], Optional [ Dict [ Any , Any ]]] :
67+ def get_single_element (self , idx ) -> SingleLandmarkData :
6568 return self .get_image (idx ), self .get_marks_with_default (idx ), self .get_meta_with_default (idx )
6669
67- def __getitem__ (
68- self , idx : int
69- ) -> Tuple [np .ndarray , Optional [np .ndarray ], Optional [Dict [Any , Any ]]]: # (image, marks, meta)
70+ def __getitem__ (self , idx : int ) -> BatchedLandmarkData :
7071 return self ._collate_fn (
7172 [self .get_single_element (i ) for i in self .idx_sampler [idx * self .batch_size : (idx + 1 ) * self .batch_size ]]
7273 )
@@ -84,30 +85,22 @@ def all_marks(self) -> np.ndarray: # (marks)
8485 def all_meta (self ) -> List : # (meta)
8586 return [self .get_meta_with_default (idx ) for idx in self .idx_sampler ]
8687
87- def __next__ (self ) -> Tuple [ np . ndarray , np . ndarray ] :
88+ def __next__ (self ) -> BatchedLandmarkData :
8889 if self .idx >= len (self ):
8990 raise StopIteration
9091 elt = self [self .idx ]
9192 self .idx += 1
9293 return elt
9394
94- def _collate_fn (
95- self , elements : List [Tuple [np .ndarray , Optional [np .ndarray ], Optional [Dict [Any , Any ]]]]
96- ) -> Tuple [np .ndarray , Optional [np .ndarray ], Optional [Dict [Any , Any ]]]:
97- batched_elements = list (zip (* elements ))
98- batched_elements [1 ] = np .array (batched_elements [1 ])
99-
95+ def _collate_fn (self , elements : List [SingleLandmarkData ]) -> BatchedLandmarkData :
96+ batched_elements = list (zip (* elements , strict = True ))
10097 # INFO: Restore if we want to concatenate all meta under one dict instead of keeping them as records (list of dicts)
10198 # meta_keys = next((list(elt.keys()) for elt in batched_elements[2] if elt is not None), [])
10299 # batched_elements[2] = {
103100 # key: [meta[key] if (meta is not None and key in meta) else None for meta in batched_elements[2]]
104101 # for key in meta_keys
105102 # }
106-
107- # if len(batched_elements[0]) != self.batch_size:
108- # raise StopIteration
109-
110- return batched_elements
103+ return batched_elements [0 ], np .array (batched_elements [1 ]), batched_elements [2 ]
111104
112105
113106class DataLoaderBase (DataIteratorBase ):
0 commit comments