1212from lib .utils .box import bbox_overlaps
1313
1414class vg_hdf5 (Dataset ):
15- def __init__ (self , cfg , split = "train" , transforms = None , num_im = - 1 ):
15+ def __init__ (self , cfg , split = "train" , transforms = None , num_im = - 1 , num_val_im = 5000 ,
16+ filter_duplicate_rels = True , filter_non_overlap = True , filter_empty_rels = True ):
1617 assert split == "train" or split == "test" , "split must be one of [train, val, test]"
1718 assert num_im >= - 1 , "the number of samples must be >= 0"
19+
20+ self .data_dir = cfg .DATASET .PATH
1821 self .transforms = transforms
19- self .data_dir = "datasets/vg_bm"
22+
23+ self .split = split
24+ self .filter_non_overlap = filter_non_overlap
25+ self .filter_duplicate_rels = filter_duplicate_rels and self .split == 'train'
26+
2027 self .roidb_file = os .path .join (self .data_dir , "VG-SGG.h5" )
2128 self .image_file = os .path .join (self .data_dir , "imdb_1024.h5" )
2229 # read in dataset from a h5 file and a dict (json) file
2330 assert os .path .exists (self .data_dir ), \
2431 "cannot find folder {}, please download the visual genome data into this folder" .format (self .data_dir )
2532 self .im_h5 = h5py .File (self .image_file , 'r' )
26- self .roi_h5 = h5py .File (os .path .join (self .data_dir , "VG-SGG.h5" ), 'r' )
2733 self .info = json .load (open (os .path .join (self .data_dir , "VG-SGG-dicts.json" ), 'r' ))
28-
2934 self .im_refs = self .im_h5 ['images' ] # image data reference
3035 im_scale = self .im_refs .shape [2 ]
3136
32- print ('split=' + split )
33- data_split = self .roi_h5 ['split' ][:]
34-
35- self .split = split
36- if split == "train" or split == "test" :
37- split_label = 0 if split == "train" else 2
38- split_mask = data_split == split_label # current split
39- else : # -1
40- split_mask = data_split >= 0 # all
41- # get rid of images that do not have box
42- valid_mask = self .roi_h5 ['img_to_first_box' ][:] >= 0
43- valid_mask = np .bitwise_and (split_mask , valid_mask )
44- self .image_index = np .where (valid_mask )[0 ] # split index
45-
46- if num_im > - 1 :
47- self .image_index = self .image_index [:num_im ]
48-
49- # override split mask
50- split_mask = np .zeros_like (data_split ).astype (bool )
51- split_mask [self .image_index ] = True # build a split mask
52- # if use all images
53- self .im_sizes = np .vstack ([self .im_h5 ['image_widths' ][split_mask ],
54- self .im_h5 ['image_heights' ][split_mask ]]).transpose ()
55-
56- # h5 file is in 1-based index
57- self .im_to_first_box = self .roi_h5 ['img_to_first_box' ][split_mask ]
58- self .im_to_last_box = self .roi_h5 ['img_to_last_box' ][split_mask ]
59- self .all_boxes = self .roi_h5 ['boxes_%i' % im_scale ][:] # will index later
60- self .all_boxes [:, :2 ] = self .all_boxes [:, :2 ]
61- assert (np .all (self .all_boxes [:, :2 ] >= 0 )) # sanity check
62- assert (np .all (self .all_boxes [:, 2 :] > 0 )) # no empty box
63-
64-
65- # convert from xc, yc, w, h to x1, y1, x2, y2
66- self .all_boxes [:, :2 ] = self .all_boxes [:, :2 ] - self .all_boxes [:, 2 :]/ 2
67- self .all_boxes [:, 2 :] = self .all_boxes [:, :2 ] + self .all_boxes [:, 2 :]
68- self .labels = self .roi_h5 ['labels' ][:,0 ]
69-
7037 # add background class
7138 self .info ['label_to_idx' ]['__background__' ] = 0
7239 self .class_to_ind = self .info ['label_to_idx' ]
7340 self .ind_to_classes = sorted (self .class_to_ind , key = lambda k :
7441 self .class_to_ind [k ])
7542 # cfg.ind_to_class = self.ind_to_classes
7643
77- # load relation labels
78- self .im_to_first_rel = self .roi_h5 ['img_to_first_rel' ][split_mask ]
79- self .im_to_last_rel = self .roi_h5 ['img_to_last_rel' ][split_mask ]
80- self ._relations = self .roi_h5 ['relationships' ][:]
81- self ._relation_predicates = self .roi_h5 ['predicates' ][:,0 ]
82- assert (self .im_to_first_rel .shape [0 ] == self .im_to_last_rel .shape [0 ])
83- assert (self ._relations .shape [0 ] == self ._relation_predicates .shape [0 ]) # sanity check
8444 self .predicate_to_ind = self .info ['predicate_to_idx' ]
8545 self .predicate_to_ind ['__background__' ] = 0
8646 self .ind_to_predicates = sorted (self .predicate_to_ind , key = lambda k :
8747 self .predicate_to_ind [k ])
88-
8948 # cfg.ind_to_predicate = self.ind_to_predicates
9049
91-
9250 self .split_mask , self .image_index , self .im_sizes , self .gt_boxes , self .gt_classes , self .relationships = load_graphs (
9351 self .roidb_file , self .image_file ,
94- self .split , num_im , num_val_im = 5000 ,
95- filter_empty_rels = True ,
96- filter_non_overlap = False and split == "train" ,
52+ self .split , num_im , num_val_im = num_val_im ,
53+ filter_empty_rels = filter_empty_rels ,
54+ filter_non_overlap = filter_non_overlap and split == "train" ,
9755 )
9856
9957 self .json_category_id_to_contiguous_id = self .class_to_ind
@@ -102,8 +60,6 @@ def __init__(self, cfg, split="train", transforms=None, num_im=-1):
10260 v : k for k , v in self .json_category_id_to_contiguous_id .items ()
10361 }
10462
105- # self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
106-
10763 @property
10864 def coco (self ):
10965 """
@@ -142,36 +98,6 @@ def _im_getter(self, idx):
14298 def __len__ (self ):
14399 return len (self .image_index )
144100
145- # def __getitem__(self, index):
146- # """
147- # get dataset item
148- # """
149- # i = index; assert(self.im_to_first_box[i] >= 0)
150- # # get image
151- # img = Image.fromarray(self._im_getter(i)); width, height = img.size
152- #
153- # # get object bounding boxes, labels and relations
154- # obj_boxes = self.all_boxes[self.im_to_first_box[i]:self.im_to_last_box[i]+1,:]
155- # obj_labels = self.labels[self.im_to_first_box[i]:self.im_to_last_box[i]+1]
156- # obj_relations = np.zeros((obj_boxes.shape[0], obj_boxes.shape[0]))
157- # if self.im_to_first_rel[i] >= 0: # if image has relations
158- # predicates = self._relation_predicates[self.im_to_first_rel[i]
159- # :self.im_to_last_rel[i]+1]
160- # obj_idx = self._relations[self.im_to_first_rel[i]
161- # :self.im_to_last_rel[i]+1]
162- # obj_idx = obj_idx - self.im_to_first_box[i]
163- # assert(np.all(obj_idx>=0) and np.all(obj_idx<obj_boxes.shape[0])) # sanity check
164- # for j, p in enumerate(predicates):
165- # # gt_relations.append([obj_idx[j][0], obj_idx[j][1], p])
166- # obj_relations[obj_idx[j][0], obj_idx[j][1]] = p
167- #
168- # target_raw = BoxList(obj_boxes, (width, height), mode="xyxy")
169- # img, target = self.transforms(img, target_raw)
170- # target.add_field("labels", torch.from_numpy(obj_labels))
171- # target.add_field("pred_labels", torch.from_numpy(obj_relations))
172- # target = target.clip_to_image(remove_empty=False)
173- # return img, target, index
174-
175101 def __getitem__ (self , index ):
176102 """
177103 get dataset item
@@ -184,6 +110,16 @@ def __getitem__(self, index):
184110 obj_labels = self .gt_classes [index ].copy ()
185111 obj_relation_triplets = self .relationships [index ].copy ()
186112
113+ if self .filter_duplicate_rels :
114+ # Filter out dupes!
115+ assert self .split == 'train'
116+ old_size = obj_relation_triplets .shape [0 ]
117+ all_rel_sets = defaultdict (list )
118+ for (o0 , o1 , r ) in obj_relation_triplets :
119+ all_rel_sets [(o0 , o1 )].append (r )
120+ obj_relation_triplets = [(k [0 ], k [1 ], np .random .choice (v )) for k ,v in all_rel_sets .items ()]
121+ obj_relation_triplets = np .array (obj_relation_triplets )
122+
187123 obj_relations = np .zeros ((obj_boxes .shape [0 ], obj_boxes .shape [0 ]))
188124
189125 for i in range (obj_relation_triplets .shape [0 ]):
@@ -209,6 +145,15 @@ def get_groundtruth(self, index):
209145 obj_labels = self .gt_classes [index ].copy ()
210146 obj_relation_triplets = self .relationships [index ].copy ()
211147
148+ if self .filter_duplicate_rels :
149+ # Filter out dupes!
150+ assert self .split == 'train'
151+ old_size = obj_relation_triplets .shape [0 ]
152+ all_rel_sets = defaultdict (list )
153+ for (o0 , o1 , r ) in obj_relation_triplets :
154+ all_rel_sets [(o0 , o1 )].append (r )
155+ obj_relation_triplets = [(k [0 ], k [1 ], np .random .choice (v )) for k ,v in all_rel_sets .items ()]
156+ obj_relation_triplets = np .array (obj_relation_triplets )
212157
213158 obj_relations = np .zeros ((obj_boxes .shape [0 ], obj_boxes .shape [0 ]))
214159
@@ -229,35 +174,6 @@ def get_img_info(self, img_id):
229174 w , h = self .im_sizes [img_id , :]
230175 return {"height" : h , "width" : w }
231176
232- # def get_groundtruth(self, index):
233- # i = index; assert(self.im_to_first_box[i] >= 0)
234- # width, height = self.im_sizes[i, :]
235- # # get object bounding boxes, labels and relations
236- # obj_boxes = self.all_boxes[self.im_to_first_box[i]:self.im_to_last_box[i]+1,:]
237- # obj_labels = self.labels[self.im_to_first_box[i]:self.im_to_last_box[i]+1]
238- # obj_relations = np.zeros((obj_boxes.shape[0], obj_boxes.shape[0]))
239- # obj_relation_triplets = np.zeros((self.im_to_last_rel[i] - self.im_to_first_rel[i] + 1, 3))
240- # if self.im_to_first_rel[i] >= 0: # if image has relations
241- # predicates = self._relation_predicates[self.im_to_first_rel[i]
242- # :self.im_to_last_rel[i]+1]
243- # obj_idx = self._relations[self.im_to_first_rel[i]
244- # :self.im_to_last_rel[i]+1]
245- # obj_idx = obj_idx - self.im_to_first_box[i]
246- # assert(np.all(obj_idx>=0) and np.all(obj_idx<obj_boxes.shape[0])) # sanity check
247- # for j, p in enumerate(predicates):
248- # # gt_relations.append([obj_idx[j][0], obj_idx[j][1], p])
249- # obj_relations[obj_idx[j][0], obj_idx[j][1]] = p
250- # obj_relation_triplets[j][0] = obj_idx[j][0]
251- # obj_relation_triplets[j][1] = obj_idx[j][1]
252- # obj_relation_triplets[j][2] = p
253- #
254- # target = BoxList(obj_boxes, (width, height), mode="xyxy")
255- # target.add_field("labels", torch.from_numpy(obj_labels))
256- # target.add_field("pred_labels", torch.from_numpy(obj_relations))
257- # target.add_field("relation_labels", torch.from_numpy(obj_relation_triplets))
258- # target.add_field("difficult", torch.from_numpy(obj_labels).clone().fill_(0))
259- # return target
260-
261177 def map_class_id_to_class_name (self , class_id ):
262178 return self .ind_to_classes [class_id ]
263179
@@ -353,7 +269,7 @@ def load_graphs(graphs_file, images_file, mode='train', num_im=-1, num_val_im=0,
353269
354270 if filter_non_overlap :
355271 assert mode == 'train'
356- inters = bbox_overlaps (torch .from_numpy (boxes_i ), torch .from_numpy (boxes_i )).numpy ()
272+ inters = bbox_overlaps (torch .from_numpy (boxes_i ). float () , torch .from_numpy (boxes_i ). float ( )).numpy ()
357273 rel_overs = inters [rels [:, 0 ], rels [:, 1 ]]
358274 inc = np .where (rel_overs > 0.0 )[0 ]
359275
0 commit comments