
Commit 769f4f4

minor change for computing frequency prior
1 parent 99f8384 commit 769f4f4

6 files changed: +118 −157 lines

README.md

Lines changed: 14 additions & 1 deletion

@@ -35,6 +35,19 @@ The goal of gathering all these representative methods into a single repo is to
 - [x] Scene Graph Generation Baseline (:balloon: 2019-07-06)
 - [x] Iterative Message Passing (IMP) (:balloon: 2019-07-07)
 - [ ] Multi-level Scene Description Network (MSDN)
-- [x] Neural Motif (Frequency Prior) (:balloon: 2019-07-08)
+- [x] Neural Motif (Frequency Prior Baseline) (:balloon: 2019-07-08)
 - [ ] Neural Motif
 - [ ] Graph R-CNN
+
+## Benchmarking
+
+### Object Detection
+
+backbone | model | #GPUs | batch size | base_lr | lr_decay_step | max_iter | mAP@0.5 | mAP@0.50:0.95
+--------|--------|--------|--------|---------|--------|--------|--------|---------
+Res101 | faster r-cnn | 6 | 6 | 5e-3 | (70k,90k) | 100k | - | -
+
+### Scene Graph Generation
+backbone | model | #GPUs | batch size | base_lr | lr_decay_step | max_iter | sgdet@20 | sgdet@50 | sgdet@100
+--------|--------|--------|--------|---------|--------|--------|--------|---------|---------
+Res101 | vanilla | 6 | 6 | 5e-3 | (70k,90k) | 100k | - | - | -

configs/faster_rcnn_res101.yaml

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 DATASET:
   NAME: "vg"
   MODE: "benchmark"
+  PATH: "datasets/vg_bm"
   TRAIN_BATCH_SIZE: 6
   TEST_BATCH_SIZE: 1
 MODEL:

lib/config/defaults.py

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@
 _C.DATASET = CN()
 _C.DATASET.NAME = "vg"
 _C.DATASET.MODE = "benchmark" # dataset mode, benchmark | 1600-400-400 | 2500-600-400, etc
+_C.DATASET.PATH = "datasets/vg_bm"
 _C.DATASET.LOADER = 'object' # which kind of data loader to use, object | object+attribute | object+attribute+relationship
 _C.DATASET.TRAIN_BATCH_SIZE = 4
 _C.DATASET.TEST_BATCH_SIZE = 4

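Together, the YAML and defaults changes above make the Visual Genome root directory configurable rather than hard-coded in the dataset class. Here is a minimal sketch of how a yacs-style config would consume the new key (the `CN()` idiom in `defaults.py` suggests yacs; the override path below is purely illustrative):

```python
# Sketch only: assumes yacs, mirroring the CN()/_C idiom in defaults.py.
from yacs.config import CfgNode as CN

_C = CN()
_C.DATASET = CN()
_C.DATASET.NAME = "vg"
_C.DATASET.MODE = "benchmark"
_C.DATASET.PATH = "datasets/vg_bm"  # default added by this commit

cfg = _C.clone()
# A YAML file or a command-line key/value pair can now relocate the data:
cfg.merge_from_list(["DATASET.PATH", "/data/vg_bm"])  # illustrative path
assert cfg.DATASET.PATH == "/data/vg_bm"
```

`vg_hdf5.__init__` (next file) then reads `cfg.DATASET.PATH` instead of the literal `"datasets/vg_bm"`.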
lib/data/vg_hdf5.py

Lines changed: 32 additions & 116 deletions

@@ -12,88 +12,46 @@
 from lib.utils.box import bbox_overlaps
 
 class vg_hdf5(Dataset):
-    def __init__(self, cfg, split="train", transforms=None, num_im=-1):
+    def __init__(self, cfg, split="train", transforms=None, num_im=-1, num_val_im=5000,
+                 filter_duplicate_rels=True, filter_non_overlap=True, filter_empty_rels=True):
         assert split == "train" or split == "test", "split must be one of [train, val, test]"
         assert num_im >= -1, "the number of samples must be >= 0"
+
+        self.data_dir = cfg.DATASET.PATH
         self.transforms = transforms
-        self.data_dir = "datasets/vg_bm"
+
+        self.split = split
+        self.filter_non_overlap = filter_non_overlap
+        self.filter_duplicate_rels = filter_duplicate_rels and self.split == 'train'
+
         self.roidb_file = os.path.join(self.data_dir, "VG-SGG.h5")
         self.image_file = os.path.join(self.data_dir, "imdb_1024.h5")
         # read in dataset from a h5 file and a dict (json) file
         assert os.path.exists(self.data_dir), \
             "cannot find folder {}, please download the visual genome data into this folder".format(self.data_dir)
         self.im_h5 = h5py.File(self.image_file, 'r')
-        self.roi_h5 = h5py.File(os.path.join(self.data_dir, "VG-SGG.h5"), 'r')
         self.info = json.load(open(os.path.join(self.data_dir, "VG-SGG-dicts.json"), 'r'))
-
         self.im_refs = self.im_h5['images'] # image data reference
         im_scale = self.im_refs.shape[2]
 
-        print('split=' + split)
-        data_split = self.roi_h5['split'][:]
-
-        self.split = split
-        if split == "train" or split == "test":
-            split_label = 0 if split == "train" else 2
-            split_mask = data_split == split_label # current split
-        else: # -1
-            split_mask = data_split >= 0 # all
-        # get rid of images that do not have box
-        valid_mask = self.roi_h5['img_to_first_box'][:] >= 0
-        valid_mask = np.bitwise_and(split_mask, valid_mask)
-        self.image_index = np.where(valid_mask)[0] # split index
-
-        if num_im > -1:
-            self.image_index = self.image_index[:num_im]
-
-        # override split mask
-        split_mask = np.zeros_like(data_split).astype(bool)
-        split_mask[self.image_index] = True # build a split mask
-        # if use all images
-        self.im_sizes = np.vstack([self.im_h5['image_widths'][split_mask],
-                                   self.im_h5['image_heights'][split_mask]]).transpose()
-
-        # h5 file is in 1-based index
-        self.im_to_first_box = self.roi_h5['img_to_first_box'][split_mask]
-        self.im_to_last_box = self.roi_h5['img_to_last_box'][split_mask]
-        self.all_boxes = self.roi_h5['boxes_%i' % im_scale][:] # will index later
-        self.all_boxes[:, :2] = self.all_boxes[:, :2]
-        assert(np.all(self.all_boxes[:, :2] >= 0)) # sanity check
-        assert(np.all(self.all_boxes[:, 2:] > 0)) # no empty box
-
-        # convert from xc, yc, w, h to x1, y1, x2, y2
-        self.all_boxes[:, :2] = self.all_boxes[:, :2] - self.all_boxes[:, 2:]/2
-        self.all_boxes[:, 2:] = self.all_boxes[:, :2] + self.all_boxes[:, 2:]
-        self.labels = self.roi_h5['labels'][:,0]
-
         # add background class
         self.info['label_to_idx']['__background__'] = 0
         self.class_to_ind = self.info['label_to_idx']
         self.ind_to_classes = sorted(self.class_to_ind, key=lambda k:
                                      self.class_to_ind[k])
         # cfg.ind_to_class = self.ind_to_classes
 
-        # load relation labels
-        self.im_to_first_rel = self.roi_h5['img_to_first_rel'][split_mask]
-        self.im_to_last_rel = self.roi_h5['img_to_last_rel'][split_mask]
-        self._relations = self.roi_h5['relationships'][:]
-        self._relation_predicates = self.roi_h5['predicates'][:,0]
-        assert(self.im_to_first_rel.shape[0] == self.im_to_last_rel.shape[0])
-        assert(self._relations.shape[0] == self._relation_predicates.shape[0]) # sanity check
         self.predicate_to_ind = self.info['predicate_to_idx']
         self.predicate_to_ind['__background__'] = 0
         self.ind_to_predicates = sorted(self.predicate_to_ind, key=lambda k:
                                         self.predicate_to_ind[k])
-
         # cfg.ind_to_predicate = self.ind_to_predicates
 
-
         self.split_mask, self.image_index, self.im_sizes, self.gt_boxes, self.gt_classes, self.relationships = load_graphs(
             self.roidb_file, self.image_file,
-            self.split, num_im, num_val_im=5000,
-            filter_empty_rels=True,
-            filter_non_overlap=False and split == "train",
+            self.split, num_im, num_val_im=num_val_im,
+            filter_empty_rels=filter_empty_rels,
+            filter_non_overlap=filter_non_overlap and split == "train",
         )
 
         self.json_category_id_to_contiguous_id = self.class_to_ind

@@ -102,8 +60,6 @@ def __init__(self, cfg, split="train", transforms=None, num_im=-1):
             v: k for k, v in self.json_category_id_to_contiguous_id.items()
         }
 
-        # self.id_to_img_map = {k: v for k, v in enumerate(self.ids)}
-
     @property
     def coco(self):
         """

@@ -142,36 +98,6 @@ def _im_getter(self, idx):
     def __len__(self):
         return len(self.image_index)
 
-    # def __getitem__(self, index):
-    #     """
-    #     get dataset item
-    #     """
-    #     i = index; assert(self.im_to_first_box[i] >= 0)
-    #     # get image
-    #     img = Image.fromarray(self._im_getter(i)); width, height = img.size
-    #
-    #     # get object bounding boxes, labels and relations
-    #     obj_boxes = self.all_boxes[self.im_to_first_box[i]:self.im_to_last_box[i]+1,:]
-    #     obj_labels = self.labels[self.im_to_first_box[i]:self.im_to_last_box[i]+1]
-    #     obj_relations = np.zeros((obj_boxes.shape[0], obj_boxes.shape[0]))
-    #     if self.im_to_first_rel[i] >= 0: # if image has relations
-    #         predicates = self._relation_predicates[self.im_to_first_rel[i]
-    #                                                :self.im_to_last_rel[i]+1]
-    #         obj_idx = self._relations[self.im_to_first_rel[i]
-    #                                   :self.im_to_last_rel[i]+1]
-    #         obj_idx = obj_idx - self.im_to_first_box[i]
-    #         assert(np.all(obj_idx>=0) and np.all(obj_idx<obj_boxes.shape[0])) # sanity check
-    #         for j, p in enumerate(predicates):
-    #             # gt_relations.append([obj_idx[j][0], obj_idx[j][1], p])
-    #             obj_relations[obj_idx[j][0], obj_idx[j][1]] = p
-    #
-    #     target_raw = BoxList(obj_boxes, (width, height), mode="xyxy")
-    #     img, target = self.transforms(img, target_raw)
-    #     target.add_field("labels", torch.from_numpy(obj_labels))
-    #     target.add_field("pred_labels", torch.from_numpy(obj_relations))
-    #     target = target.clip_to_image(remove_empty=False)
-    #     return img, target, index
-
     def __getitem__(self, index):
         """
         get dataset item

@@ -184,6 +110,16 @@ def __getitem__(self, index):
         obj_labels = self.gt_classes[index].copy()
         obj_relation_triplets = self.relationships[index].copy()
 
+        if self.filter_duplicate_rels:
+            # Filter out dupes!
+            assert self.split == 'train'
+            old_size = obj_relation_triplets.shape[0]
+            all_rel_sets = defaultdict(list)
+            for (o0, o1, r) in obj_relation_triplets:
+                all_rel_sets[(o0, o1)].append(r)
+            obj_relation_triplets = [(k[0], k[1], np.random.choice(v)) for k,v in all_rel_sets.items()]
+            obj_relation_triplets = np.array(obj_relation_triplets)
+
         obj_relations = np.zeros((obj_boxes.shape[0], obj_boxes.shape[0]))
 
         for i in range(obj_relation_triplets.shape[0]):

@@ -209,6 +145,15 @@ def get_groundtruth(self, index):
         obj_labels = self.gt_classes[index].copy()
         obj_relation_triplets = self.relationships[index].copy()
 
+        if self.filter_duplicate_rels:
+            # Filter out dupes!
+            assert self.split == 'train'
+            old_size = obj_relation_triplets.shape[0]
+            all_rel_sets = defaultdict(list)
+            for (o0, o1, r) in obj_relation_triplets:
+                all_rel_sets[(o0, o1)].append(r)
+            obj_relation_triplets = [(k[0], k[1], np.random.choice(v)) for k,v in all_rel_sets.items()]
+            obj_relation_triplets = np.array(obj_relation_triplets)
 
         obj_relations = np.zeros((obj_boxes.shape[0], obj_boxes.shape[0]))
 

@@ -229,35 +174,6 @@ def get_img_info(self, img_id):
         w, h = self.im_sizes[img_id, :]
         return {"height": h, "width": w}
 
-    # def get_groundtruth(self, index):
-    #     i = index; assert(self.im_to_first_box[i] >= 0)
-    #     width, height = self.im_sizes[i, :]
-    #     # get object bounding boxes, labels and relations
-    #     obj_boxes = self.all_boxes[self.im_to_first_box[i]:self.im_to_last_box[i]+1,:]
-    #     obj_labels = self.labels[self.im_to_first_box[i]:self.im_to_last_box[i]+1]
-    #     obj_relations = np.zeros((obj_boxes.shape[0], obj_boxes.shape[0]))
-    #     obj_relation_triplets = np.zeros((self.im_to_last_rel[i] - self.im_to_first_rel[i] + 1, 3))
-    #     if self.im_to_first_rel[i] >= 0: # if image has relations
-    #         predicates = self._relation_predicates[self.im_to_first_rel[i]
-    #                                                :self.im_to_last_rel[i]+1]
-    #         obj_idx = self._relations[self.im_to_first_rel[i]
-    #                                   :self.im_to_last_rel[i]+1]
-    #         obj_idx = obj_idx - self.im_to_first_box[i]
-    #         assert(np.all(obj_idx>=0) and np.all(obj_idx<obj_boxes.shape[0])) # sanity check
-    #         for j, p in enumerate(predicates):
-    #             # gt_relations.append([obj_idx[j][0], obj_idx[j][1], p])
-    #             obj_relations[obj_idx[j][0], obj_idx[j][1]] = p
-    #             obj_relation_triplets[j][0] = obj_idx[j][0]
-    #             obj_relation_triplets[j][1] = obj_idx[j][1]
-    #             obj_relation_triplets[j][2] = p
-    #
-    #     target = BoxList(obj_boxes, (width, height), mode="xyxy")
-    #     target.add_field("labels", torch.from_numpy(obj_labels))
-    #     target.add_field("pred_labels", torch.from_numpy(obj_relations))
-    #     target.add_field("relation_labels", torch.from_numpy(obj_relation_triplets))
-    #     target.add_field("difficult", torch.from_numpy(obj_labels).clone().fill_(0))
-    #     return target
-
     def map_class_id_to_class_name(self, class_id):
         return self.ind_to_classes[class_id]
 

@@ -353,7 +269,7 @@ def load_graphs(graphs_file, images_file, mode='train', num_im=-1, num_val_im=0,
 
         if filter_non_overlap:
            assert mode == 'train'
-            inters = bbox_overlaps(torch.from_numpy(boxes_i), torch.from_numpy(boxes_i)).numpy()
+            inters = bbox_overlaps(torch.from_numpy(boxes_i).float(), torch.from_numpy(boxes_i).float()).numpy()
            rel_overs = inters[rels[:, 0], rels[:, 1]]
            inc = np.where(rel_overs > 0.0)[0]

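The two identical hunks in `__getitem__` and `get_groundtruth` add the same step: when a (subject, object) pair carries several annotated predicates, keep exactly one, sampled uniformly at random, and only at training time. A self-contained sketch of that logic with toy triplets (in the dataset class the input comes from `self.relationships[index]`):

```python
# Standalone sketch of the duplicate-relation filtering added above;
# the toy triplets are illustrative.
from collections import defaultdict
import numpy as np

def filter_duplicate_rels(obj_relation_triplets):
    """Keep one randomly sampled predicate per (subject, object) pair."""
    all_rel_sets = defaultdict(list)
    for (o0, o1, r) in obj_relation_triplets:
        all_rel_sets[(o0, o1)].append(r)
    return np.array([(o0, o1, np.random.choice(rs))
                     for (o0, o1), rs in all_rel_sets.items()])

triplets = np.array([[0, 1, 20], [0, 1, 31], [2, 0, 7]])  # pair (0, 1) duplicated
print(filter_duplicate_rels(triplets))  # one triplet per unique pair
```

The `.float()` casts added in `load_graphs` are a dtype fix: the boxes come out of the h5 file as integer arrays, and `bbox_overlaps` presumably needs floating-point tensors for its IoU division, so the inputs are cast before the call.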
lib/model.py

Lines changed: 66 additions & 34 deletions

@@ -31,13 +31,20 @@ def __init__(self, cfg, arguments, local_rank, distributed):
         self.data_loader_train = build_data_loader(cfg, split="train", is_distributed=distributed)
         self.data_loader_test = build_data_loader(cfg, split="test", is_distributed=distributed)
 
+        logger = logging.getLogger("scene_graph_generation.trainer")
+        logger.info("Train data size: {}".format(len(self.data_loader_train.dataset)))
+        logger.info("Test data size: {}".format(len(self.data_loader_test.dataset)))
+
         if not os.path.exists("freq_prior.npy"):
-            freq_prior = self._get_freq_prior()
-            np.save("freq_prior.npy", freq_prior)
-        else:
-            freq_prior = np.load("freq_prior.npy")
+            logger.info("Computing frequency prior matrix...")
+            fg_matrix, bg_matrix = self._get_freq_prior()
+            prob_matrix = fg_matrix.astype(np.float32)
+            prob_matrix[:,:,0] = bg_matrix
 
-        self.freq_prior = freq_prior
+            prob_matrix[:,:,0] += 1
+            prob_matrix /= np.sum(prob_matrix, 2)[:,:,None]
+            # prob_matrix /= float(fg_matrix.max())
+            np.save("freq_prior.npy", prob_matrix)
 
         # build scene graph generation model
         self.scene_parser = build_scene_parser(cfg); self.scene_parser.to(self.device)

@@ -46,34 +53,59 @@ def __init__(self, cfg, arguments, local_rank, distributed):
 
         self.arguments.update(self.extra_checkpoint_data)
 
-    def _get_freq_prior(self):
-        """
-        get the frequency prior for object-pair v.s. predicate
-        """
-        freq_prior = np.zeros((self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES,
-                               self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES,
-                               self.cfg.MODEL.ROI_RELATION_HEAD.NUM_CLASSES))
-
-        for i in range(len(self.data_loader_train.dataset)):
-            target = self.data_loader_train.dataset.get_groundtruth(i)
-            boxes = target.bbox
-            overlaps = bbox_overlaps(boxes, boxes)
-            labels = target.get_field("labels")
-            pred_labels = target.get_field("pred_labels")
-            for m in range(pred_labels.size(0)):
-                for n in range(pred_labels.size(1)):
-                    if pred_labels[m, n] > 0:
-                        label_m = labels[m].item()
-                        label_n = labels[n].item()
-                        freq_prior[label_m, label_n][int(pred_labels[m, n].item())] += 1
-                    else:
-                        if overlaps[m, n] > 0 and m != n:
-                            freq_prior[label_m, label_n][0] += 1
-            if i % 20 == 0:
-                print("processing {}/{}".format(i, len(self.data_loader_train.dataset)))
-            if i >= len(self.data_loader_train.dataset):
-                break
-        return freq_prior
+    def _get_freq_prior(self, must_overlap=False):
+
+        fg_matrix = np.zeros((
+            self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES,
+            self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES,
+            self.cfg.MODEL.ROI_RELATION_HEAD.NUM_CLASSES
+        ), dtype=np.int64)
+
+        bg_matrix = np.zeros((
+            self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES,
+            self.cfg.MODEL.ROI_BOX_HEAD.NUM_CLASSES,
+        ), dtype=np.int64)
+
+        for ex_ind in range(len(self.data_loader_train.dataset)):
+            gt_classes = self.data_loader_train.dataset.gt_classes[ex_ind].copy()
+            gt_relations = self.data_loader_train.dataset.relationships[ex_ind].copy()
+            gt_boxes = self.data_loader_train.dataset.gt_boxes[ex_ind].copy()
+
+            # For the foreground, we'll just look at everything
+            o1o2 = gt_classes[gt_relations[:, :2]]
+            for (o1, o2), gtr in zip(o1o2, gt_relations[:,2]):
+                fg_matrix[o1, o2, gtr] += 1
+
+            # For the background, get all of the things that overlap.
+            o1o2_total = gt_classes[np.array(
+                self._box_filter(gt_boxes, must_overlap=must_overlap), dtype=int)]
+            for (o1, o2) in o1o2_total:
+                bg_matrix[o1, o2] += 1
+
+            if ex_ind % 20 == 0:
+                print("processing {}/{}".format(ex_ind, len(self.data_loader_train.dataset)))
+
+        return fg_matrix, bg_matrix
+
+    def _box_filter(self, boxes, must_overlap=False):
+        """ Only include boxes that overlap as possible relations.
+        If no overlapping boxes, use all of them."""
+        n_cands = boxes.shape[0]
+
+        overlaps = bbox_overlaps(torch.from_numpy(boxes.astype(np.float)), torch.from_numpy(boxes.astype(np.float))).numpy() > 0
+        np.fill_diagonal(overlaps, 0)
+
+        all_possib = np.ones_like(overlaps, dtype=np.bool)
+        np.fill_diagonal(all_possib, 0)
+
+        if must_overlap:
+            possible_boxes = np.column_stack(np.where(overlaps))
+
+            if possible_boxes.size == 0:
+                possible_boxes = np.column_stack(np.where(all_possib))
+        else:
+            possible_boxes = np.column_stack(np.where(all_possib))
+        return possible_boxes
 
     def train(self):
         """

@@ -260,7 +292,7 @@ def test(self, timer=None, visualize=False):
                                      predictions=predictions,
                                      output_folder=output_folder,
                                      **extra_args)
-
+
         if self.cfg.MODEL.RELATION_ON:
             eval_sg_results = evaluate_sg(dataset=self.data_loader_test.dataset,
                                           predictions=predictions,

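The refactored `_get_freq_prior` now returns raw counts: `fg_matrix[s, o, p]` counts how often predicate `p` links a subject of class `s` to an object of class `o` in the training annotations, while `bg_matrix[s, o]` counts candidate object pairs from `_box_filter` (all distinct ordered pairs by default, only overlapping ones when `must_overlap=True`, with a fall-back to all pairs if nothing overlaps). The constructor then folds the background counts into predicate channel 0 (the `__background__` predicate), adds 1 to that channel as smoothing, and normalizes along the predicate axis. A toy-sized sketch of that normalization (the 3 classes and 4 predicates are illustrative stand-ins for the config's `NUM_CLASSES` values):

```python
# Toy-shape sketch of the prior normalization done in __init__ above;
# sizes are illustrative stand-ins for the config's NUM_CLASSES values.
import numpy as np

num_obj_cls, num_pred_cls = 3, 4
rng = np.random.default_rng(0)
fg_matrix = rng.integers(0, 5, size=(num_obj_cls, num_obj_cls, num_pred_cls))
bg_matrix = rng.integers(0, 5, size=(num_obj_cls, num_obj_cls))

prob_matrix = fg_matrix.astype(np.float32)
prob_matrix[:, :, 0] = bg_matrix      # channel 0 is the __background__ predicate
prob_matrix[:, :, 0] += 1             # smoothing: every pair keeps background mass
prob_matrix /= np.sum(prob_matrix, 2)[:, :, None]  # normalize over predicates

assert np.allclose(prob_matrix.sum(axis=2), 1.0)  # each pair is a distribution
```

Note that the constructor no longer keeps a `self.freq_prior` attribute in memory: the normalized matrix is only written to `freq_prior.npy`, and when that file already exists nothing is recomputed or loaded here.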