Commit 59fc5e73 authored by yinchimaoliang, committed by GitHub

Add datasets unittest (#24)

* add test_getitem for kitti

* Change dataset to self

* Fix dbinfo path bug

* Change to 230

* Change bboxes and scores

* Debug, check points shape after dbsample

* Debug, add assert gt_bbox and db bbox

* Debug, add assert sampled and gt_bboxes

* Debug, add assertion for dbinfo and gt_box

* Fix transform_3d unittest

* Fix kitti dataset unittest

* Add kitti data

* Clean debug

* Change data_augment_utils

* Reduce point file

* Change self to dataset

* Change data_augment_utils to normal

* Change order of CLASSES

* Finish test_random_flip_3d

* Add show unittest for scannet and sunrgbd datasets

* Add show unittest for kitti dataset

* Add test_evaluate for kitti dataset

* Add pytest.skip

* Add format_results unittest

* Add test_show for lyft

* Fix eval

* Add test_time_aug unittest

* Add bbox2result_kitti2d unittest

* Change abs to np.isclose
parent a5daf209
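
The last bullet ('Change abs to np.isclose') swaps hand-rolled absolute-difference checks for NumPy's tolerance-aware comparison. A minimal illustration of why (standalone sketch, not code from this commit):

import numpy as np

a, b = 3.0303030303030307, 3.0303030303030303
assert abs(a - b) < 1e-6             # works, but the threshold is arbitrary
assert np.isclose(a, b)              # rtol=1e-5, atol=1e-8 by default
assert np.isclose(1e9, 1e9 + 1.0)    # passes via the relative tolerance
assert not abs(1e9 - (1e9 + 1.0)) < 1e-6  # a fixed absolute threshold fails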
@@ -472,8 +472,9 @@ def eval_class(gt_annos,
     """
     assert len(gt_annos) == len(dt_annos)
     num_examples = len(gt_annos)
+    if num_examples < num_parts:
+        num_parts = num_examples
     split_parts = get_split_parts(num_examples, num_parts)
     rets = calculate_iou_partly(dt_annos, gt_annos, metric, num_parts)
     overlaps, parted_overlaps, total_dt_num, total_gt_num = rets
     N_SAMPLE_PTS = 41
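
Why the added guard matters: these unit tests evaluate on a single KITTI sample, so the number of examples can be smaller than num_parts and the split would otherwise produce empty chunks. A minimal sketch of the clamping logic (hypothetical split_parts helper, not the library's get_split_parts):

def split_parts(num_examples, num_parts):
    # Clamp, as in the hunk above: never request more parts than examples.
    if num_examples < num_parts:
        num_parts = num_examples
    base, rem = divmod(num_examples, num_parts)
    return [base] * num_parts + ([rem] if rem else [])

# One test sample with default-sized chunking: without the clamp this
# would yield dozens of zero-length parts; with it, one part of size 1.
assert split_parts(1, 100) == [1]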
@@ -110,6 +110,6 @@ class MultiScaleFlipAug3D(object):
         repr_str = self.__class__.__name__
         repr_str += f'(transforms={self.transforms}, '
         repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
-        repr_str += f'pts_scale_ratio={self.pts_scale_raio}, '
+        repr_str += f'pts_scale_ratio={self.pts_scale_ratio}, '
         repr_str += f'flip_direction={self.flip_direction})'
         return repr_str
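
The fixed line is more than cosmetic: the old name pts_scale_raio is not an attribute of the transform, so calling repr() on it raised AttributeError. A standalone sketch of the failure mode (hypothetical Demo class, not mmdet3d code):

class Demo:
    def __init__(self):
        self.pts_scale_ratio = 1

    def __repr__(self):
        # Typo reproduced: 'raio' instead of 'ratio' -> AttributeError.
        return f'pts_scale_ratio={self.pts_scale_raio}'

try:
    repr(Demo())
except AttributeError as err:
    print(err)  # 'Demo' object has no attribute 'pts_scale_raio'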
@@ -104,9 +104,7 @@ class RandomFlip3D(RandomFlip):
         """str: Return a string that describes the module."""
         repr_str = self.__class__.__name__
         repr_str += '(sync_2d={},'.format(self.sync_2d)
-        repr_str += '(flip_ratio_bev_horizontal={},'.format(
-            self.flip_ratio_bev_horizontal)
-        repr_str += '(flip_ratio_bev_vertical={},'.format(
+        repr_str += 'flip_ratio_bev_vertical={})'.format(
             self.flip_ratio_bev_vertical)
         return repr_str
[Binary test data files added: no preview]
import numpy as np
import pytest
import torch
from mmdet3d.core.bbox import LiDARInstance3DBoxes
from mmdet3d.datasets import KittiDataset

def test_getitem():
np.random.seed(0)
data_root = 'tests/data/kitti'
ann_file = 'tests/data/kitti/kitti_infos_train.pkl'
classes = ['Pedestrian', 'Cyclist', 'Car']
pts_prefix = 'velodyne_reduced'
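    # Training pipeline: load points and 3D annotations, paste GT samples
    # from the database (ObjectSample), perturb objects (ObjectNoise),
    # flip, global rot/scale, range-filter points and boxes, shuffle,
    # then bundle and collect tensors.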
pipeline = [{
'type': 'LoadPointsFromFile',
'load_dim': 4,
'use_dim': 4,
'file_client_args': {
'backend': 'disk'
}
}, {
'type': 'LoadAnnotations3D',
'with_bbox_3d': True,
'with_label_3d': True,
'file_client_args': {
'backend': 'disk'
}
}, {
'type': 'ObjectSample',
'db_sampler': {
'data_root': 'tests/data/kitti/',
'info_path': 'tests/data/kitti/kitti_dbinfos_train.pkl',
'rate': 1.0,
'prepare': {
'filter_by_difficulty': [-1],
'filter_by_min_points': {
'Pedestrian': 10
}
},
'classes': ['Pedestrian', 'Cyclist', 'Car'],
'sample_groups': {
'Pedestrian': 6
}
}
}, {
'type': 'ObjectNoise',
'num_try': 100,
'translation_std': [1.0, 1.0, 0.5],
'global_rot_range': [0.0, 0.0],
'rot_range': [-0.78539816, 0.78539816]
}, {
'type': 'RandomFlip3D',
'flip_ratio_bev_horizontal': 0.5
}, {
'type': 'GlobalRotScaleTrans',
'rot_range': [-0.78539816, 0.78539816],
'scale_ratio_range': [0.95, 1.05]
}, {
'type': 'PointsRangeFilter',
'point_cloud_range': [0, -40, -3, 70.4, 40, 1]
}, {
'type': 'ObjectRangeFilter',
'point_cloud_range': [0, -40, -3, 70.4, 40, 1]
}, {
'type': 'PointShuffle'
}, {
'type': 'DefaultFormatBundle3D',
'class_names': ['Pedestrian', 'Cyclist', 'Car']
}, {
'type': 'Collect3D',
'keys': ['points', 'gt_bboxes_3d', 'gt_labels_3d']
}]
modality = {'use_lidar': True, 'use_camera': False}
split = 'training'
self = KittiDataset(data_root, ann_file, split, pts_prefix, pipeline,
classes, modality)
data = self[0]
points = data['points']._data
gt_bboxes_3d = data['gt_bboxes_3d']._data
gt_labels_3d = data['gt_labels_3d']._data
expected_gt_bboxes_3d = torch.tensor(
[[9.5081, -5.2269, -1.1370, 0.4915, 1.2288, 1.9353, -2.7136]])
expected_gt_labels_3d = torch.tensor([0])
assert points.shape == (780, 4)
assert torch.allclose(
gt_bboxes_3d.tensor, expected_gt_bboxes_3d, atol=1e-4)
assert torch.all(gt_labels_3d == expected_gt_labels_3d)

def test_evaluate():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
data_root = 'tests/data/kitti'
ann_file = 'tests/data/kitti/kitti_infos_train.pkl'
classes = ['Pedestrian', 'Cyclist', 'Car']
pts_prefix = 'velodyne_reduced'
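    # Inference pipeline: load points, then an identity rot/scale and
    # no flipping inside the MultiScaleFlipAug3D test-time wrapper.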
pipeline = [{
'type': 'LoadPointsFromFile',
'load_dim': 4,
'use_dim': 4,
'file_client_args': {
'backend': 'disk'
}
}, {
'type':
'MultiScaleFlipAug3D',
'img_scale': (1333, 800),
'pts_scale_ratio':
1,
'flip':
False,
'transforms': [{
'type': 'GlobalRotScaleTrans',
'rot_range': [0, 0],
'scale_ratio_range': [1.0, 1.0],
'translation_std': [0, 0, 0]
}, {
'type': 'RandomFlip3D'
}, {
'type': 'PointsRangeFilter',
'point_cloud_range': [0, -40, -3, 70.4, 40, 1]
}, {
'type': 'DefaultFormatBundle3D',
'class_names': ['Pedestrian', 'Cyclist', 'Car'],
'with_label': False
}, {
'type': 'Collect3D',
'keys': ['points']
}]
}]
modality = {'use_lidar': True, 'use_camera': False}
split = 'training'
self = KittiDataset(
data_root,
ann_file,
split,
pts_prefix,
pipeline,
classes,
modality,
)
boxes_3d = LiDARInstance3DBoxes(
torch.tensor(
[[8.7314, -1.8559, -1.5997, 0.4800, 1.2000, 1.8900, 0.0100]]))
labels_3d = torch.tensor([
0,
])
scores_3d = torch.tensor([0.5])
metric = ['mAP']
result = dict(boxes_3d=boxes_3d, labels_3d=labels_3d, scores_3d=scores_3d)
ap_dict = self.evaluate([result], metric)
assert np.isclose(ap_dict['KITTI/Overall_3D_easy'], 3.0303030303030307)
assert np.isclose(ap_dict['KITTI/Overall_3D_moderate'], 3.0303030303030307)
assert np.isclose(ap_dict['KITTI/Overall_3D_hard'], 3.0303030303030307)

def test_show():
import mmcv
import tempfile
from os import path as osp
from mmdet3d.core.bbox import LiDARInstance3DBoxes
temp_dir = tempfile.mkdtemp()
data_root = 'tests/data/kitti'
ann_file = 'tests/data/kitti/kitti_infos_train.pkl'
modality = {'use_lidar': True, 'use_camera': False}
split = 'training'
file_client_args = dict(backend='disk')
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
class_names = ['Pedestrian', 'Cyclist', 'Car']
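    # show() re-runs this pipeline to recover the raw points, then writes
    # an .obj file for the cloud and .ply files for GT and predicted boxes.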
pipeline = [
dict(
type='LoadPointsFromFile',
load_dim=4,
use_dim=4,
file_client_args=file_client_args),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter',
point_cloud_range=point_cloud_range),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
]
kitti_dataset = KittiDataset(
data_root, ann_file, split=split, modality=modality, pipeline=pipeline)
boxes_3d = LiDARInstance3DBoxes(
torch.tensor(
[[46.1218, -4.6496, -0.9275, 0.5316, 1.4442, 1.7450, 1.1749],
[33.3189, 0.1981, 0.3136, 0.5656, 1.2301, 1.7985, 1.5723],
[46.1366, -4.6404, -0.9510, 0.5162, 1.6501, 1.7540, 1.3778],
[33.2646, 0.2297, 0.3446, 0.5746, 1.3365, 1.7947, 1.5430],
[58.9079, 16.6272, -1.5829, 1.5656, 3.9313, 1.4899, 1.5505]]))
scores_3d = torch.tensor([0.1815, 0.1663, 0.5792, 0.2194, 0.2780])
labels_3d = torch.tensor([0, 0, 1, 1, 2])
result = dict(boxes_3d=boxes_3d, scores_3d=scores_3d, labels_3d=labels_3d)
results = [result]
kitti_dataset.show(results, temp_dir)
pts_file_path = osp.join(temp_dir, '000000', '000000_points.obj')
gt_file_path = osp.join(temp_dir, '000000', '000000_gt.ply')
pred_file_path = osp.join(temp_dir, '000000', '000000_pred.ply')
mmcv.check_file_exist(pts_file_path)
mmcv.check_file_exist(gt_file_path)
mmcv.check_file_exist(pred_file_path)

def test_format_results():
from mmdet3d.core.bbox import LiDARInstance3DBoxes
data_root = 'tests/data/kitti'
ann_file = 'tests/data/kitti/kitti_infos_train.pkl'
classes = ['Pedestrian', 'Cyclist', 'Car']
pts_prefix = 'velodyne_reduced'
pipeline = [{
'type': 'LoadPointsFromFile',
'load_dim': 4,
'use_dim': 4,
'file_client_args': {
'backend': 'disk'
}
}, {
'type':
'MultiScaleFlipAug3D',
'img_scale': (1333, 800),
'pts_scale_ratio':
1,
'flip':
False,
'transforms': [{
'type': 'GlobalRotScaleTrans',
'rot_range': [0, 0],
'scale_ratio_range': [1.0, 1.0],
'translation_std': [0, 0, 0]
}, {
'type': 'RandomFlip3D'
}, {
'type': 'PointsRangeFilter',
'point_cloud_range': [0, -40, -3, 70.4, 40, 1]
}, {
'type': 'DefaultFormatBundle3D',
'class_names': ['Pedestrian', 'Cyclist', 'Car'],
'with_label': False
}, {
'type': 'Collect3D',
'keys': ['points']
}]
}]
modality = {'use_lidar': True, 'use_camera': False}
split = 'training'
self = KittiDataset(
data_root,
ann_file,
split,
pts_prefix,
pipeline,
classes,
modality,
)
boxes_3d = LiDARInstance3DBoxes(
torch.tensor(
[[8.7314, -1.8559, -1.5997, 0.4800, 1.2000, 1.8900, 0.0100]]))
labels_3d = torch.tensor([
0,
])
scores_3d = torch.tensor([0.5])
result = dict(boxes_3d=boxes_3d, labels_3d=labels_3d, scores_3d=scores_3d)
results = [result]
result_files, _ = self.format_results(results)
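    # format_results converts LiDAR-frame boxes into KITTI camera-frame
    # annotations, so location/rotation_y below differ from boxes_3d above.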
expected_name = np.array(['Pedestrian'])
expected_truncated = np.array([0.])
expected_occluded = np.array([0])
expected_alpha = np.array([-3.3410306])
expected_bbox = np.array([[710.443, 144.00221, 820.29114, 307.58667]])
expected_dimensions = np.array([[1.2, 1.89, 0.48]])
expected_location = np.array([[1.8399826, 1.4700007, 8.410018]])
expected_rotation_y = np.array([-3.1315928])
expected_score = np.array([0.5])
expected_sample_idx = np.array([0])
assert np.all(result_files[0]['name'] == expected_name)
assert np.allclose(result_files[0]['truncated'], expected_truncated)
assert np.all(result_files[0]['occluded'] == expected_occluded)
assert np.allclose(result_files[0]['alpha'], expected_alpha)
assert np.allclose(result_files[0]['bbox'], expected_bbox)
assert np.allclose(result_files[0]['dimensions'], expected_dimensions)
assert np.allclose(result_files[0]['location'], expected_location)
assert np.allclose(result_files[0]['rotation_y'], expected_rotation_y)
assert np.allclose(result_files[0]['score'], expected_score)
assert np.allclose(result_files[0]['sample_idx'], expected_sample_idx)

def test_bbox2result_kitti2d():
data_root = 'tests/data/kitti'
ann_file = 'tests/data/kitti/kitti_infos_train.pkl'
classes = ['Pedestrian', 'Cyclist', 'Car']
pts_prefix = 'velodyne_reduced'
pipeline = [{
'type': 'LoadPointsFromFile',
'load_dim': 4,
'use_dim': 4,
'file_client_args': {
'backend': 'disk'
}
}, {
'type':
'MultiScaleFlipAug3D',
'img_scale': (1333, 800),
'pts_scale_ratio':
1,
'flip':
False,
'transforms': [{
'type': 'GlobalRotScaleTrans',
'rot_range': [-0.1, 0.1],
'scale_ratio_range': [0.9, 1.1],
'translation_std': [0, 0, 0]
}, {
'type': 'RandomFlip3D'
}, {
'type': 'PointsRangeFilter',
'point_cloud_range': [0, -40, -3, 70.4, 40, 1]
}, {
'type': 'DefaultFormatBundle3D',
'class_names': ['Pedestrian', 'Cyclist', 'Car'],
'with_label': False
}, {
'type': 'Collect3D',
'keys': ['points']
}]
}]
modality = {'use_lidar': True, 'use_camera': False}
split = 'training'
self = KittiDataset(
data_root,
ann_file,
split,
pts_prefix,
pipeline,
classes,
modality,
)
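    # Fake 2D detections for one sample: axis 0 indexes the class
    # (Pedestrian, Cyclist), each row is (x1, y1, x2, y2, score).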
bboxes = np.array([[[46.1218, -4.6496, -0.9275, 0.5316, 0.5],
[33.3189, 0.1981, 0.3136, 0.5656, 0.5]],
[[46.1366, -4.6404, -0.9510, 0.5162, 0.5],
[33.2646, 0.2297, 0.3446, 0.5746, 0.5]]])
det_annos = self.bbox2result_kitti2d([bboxes], classes)
expected_name = np.array(
['Pedestrian', 'Pedestrian', 'Cyclist', 'Cyclist'])
expected_bbox = np.array([[46.1218, -4.6496, -0.9275, 0.5316],
[33.3189, 0.1981, 0.3136, 0.5656],
[46.1366, -4.6404, -0.951, 0.5162],
[33.2646, 0.2297, 0.3446, 0.5746]])
expected_score = np.array([0.5, 0.5, 0.5, 0.5])
assert np.all(det_annos[0]['name'] == expected_name)
assert np.allclose(det_annos[0]['bbox'], expected_bbox)
assert np.allclose(det_annos[0]['score'], expected_score)
@@ -162,3 +162,50 @@ def test_evaluate():
assert abs(ret_dict['window_AP_0.25'] - 1.0) < 0.01
assert abs(ret_dict['counter_AP_0.25'] - 1.0) < 0.01
assert abs(ret_dict['curtain_AP_0.25'] - 1.0) < 0.01

def test_show():
import mmcv
import tempfile
from os import path as osp
from mmdet3d.core.bbox import DepthInstance3DBoxes
temp_dir = tempfile.mkdtemp()
root_path = './tests/data/scannet'
ann_file = './tests/data/scannet/scannet_infos.pkl'
scannet_dataset = ScanNetDataset(root_path, ann_file)
boxes_3d = DepthInstance3DBoxes(
torch.tensor([[
-2.4053e+00, 9.2295e-01, 8.0661e-02, 2.4054e+00, 2.1468e+00,
8.5990e-01, 0.0000e+00
],
[
-1.9341e+00, -2.0741e+00, 3.0698e-03, 3.2206e-01,
2.5322e-01, 3.5144e-01, 0.0000e+00
],
[
-3.6908e+00, 8.0684e-03, 2.6201e-01, 4.1515e-01,
7.6489e-01, 5.3585e-01, 0.0000e+00
],
[
2.6332e+00, 8.5143e-01, -4.9964e-03, 3.0367e-01,
1.3448e+00, 1.8329e+00, 0.0000e+00
],
[
2.0221e-02, 2.6153e+00, 1.5109e-02, 7.3335e-01,
1.0429e+00, 1.0251e+00, 0.0000e+00
]]))
scores_3d = torch.tensor(
[1.2058e-04, 2.3012e-03, 6.2324e-06, 6.6139e-06, 6.7965e-05])
labels_3d = torch.tensor([0, 0, 0, 0, 0])
result = dict(boxes_3d=boxes_3d, scores_3d=scores_3d, labels_3d=labels_3d)
results = [result]
scannet_dataset.show(results, temp_dir)
pts_file_path = osp.join(temp_dir, 'scene0000_00',
'scene0000_00_points.obj')
gt_file_path = osp.join(temp_dir, 'scene0000_00', 'scene0000_00_gt.ply')
pred_file_path = osp.join(temp_dir, 'scene0000_00',
'scene0000_00_pred.ply')
mmcv.check_file_exist(pts_file_path)
mmcv.check_file_exist(gt_file_path)
mmcv.check_file_exist(pred_file_path)
@@ -120,3 +120,34 @@ def test_evaluate():
assert abs(bed_precision_25 - 1) < 0.01
assert abs(dresser_precision_25 - 1) < 0.01
assert abs(night_stand_precision_25 - 1) < 0.01

def test_show():
import mmcv
import tempfile
from os import path as osp
from mmdet3d.core.bbox import DepthInstance3DBoxes
temp_dir = tempfile.mkdtemp()
root_path = './tests/data/sunrgbd'
ann_file = './tests/data/sunrgbd/sunrgbd_infos.pkl'
sunrgbd_dataset = SUNRGBDDataset(root_path, ann_file)
boxes_3d = DepthInstance3DBoxes(
torch.tensor(
[[1.1500, 4.2614, -1.0669, 1.3219, 2.1593, 1.0267, 1.6473],
[-0.9583, 2.1916, -1.0881, 0.6213, 1.3022, 1.6275, -3.0720],
[2.5697, 4.8152, -1.1157, 0.5421, 0.7019, 0.7896, 1.6712],
[0.7283, 2.5448, -1.0356, 0.7691, 0.9056, 0.5771, 1.7121],
[-0.9860, 3.2413, -1.2349, 0.5110, 0.9940, 1.1245, 0.3295]]))
scores_3d = torch.tensor(
[1.5280e-01, 1.6682e-03, 6.2811e-04, 1.2860e-03, 9.4229e-06])
labels_3d = torch.tensor([0, 0, 0, 0, 0])
result = dict(boxes_3d=boxes_3d, scores_3d=scores_3d, labels_3d=labels_3d)
results = [result]
sunrgbd_dataset.show(results, temp_dir)
pts_file_path = osp.join(temp_dir, '000001', '000001_points.obj')
gt_file_path = osp.join(temp_dir, '000001', '000001_gt.ply')
pred_file_path = osp.join(temp_dir, '000001', '000001_pred.ply')
mmcv.check_file_exist(pts_file_path)
mmcv.check_file_exist(gt_file_path)
mmcv.check_file_exist(pred_file_path)

import numpy as np
import torch
from mmdet3d.datasets.pipelines import MultiScaleFlipAug3D

def test_multi_scale_flip_aug_3D():
np.random.seed(0)
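    # Seed NumPy so GlobalRotScaleTrans and RandomFlip3D draw the same
    # random values and the hard-coded expected points stay valid.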
transforms = [{
'type': 'GlobalRotScaleTrans',
'rot_range': [-0.1, 0.1],
'scale_ratio_range': [0.9, 1.1],
'translation_std': [0, 0, 0]
}, {
'type': 'RandomFlip3D',
'sync_2d': False,
'flip_ratio_bev_horizontal': 0.5
}, {
'type': 'IndoorPointSample',
'num_points': 5
}, {
'type':
'DefaultFormatBundle3D',
'class_names': ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk',
'dresser', 'night_stand', 'bookshelf', 'bathtub'),
'with_label':
False
}, {
'type': 'Collect3D',
'keys': ['points']
}]
img_scale = (1333, 800)
pts_scale_ratio = 1
multi_scale_flip_aug_3D = MultiScaleFlipAug3D(transforms, img_scale,
pts_scale_ratio)
pts_file_name = 'tests/data/sunrgbd/points/000001.bin'
sample_idx = 4
file_name = 'tests/data/sunrgbd/points/000001.bin'
bbox3d_fields = []
points = np.array([[0.20397437, 1.4267826, -1.0503972, 0.16195858],
[-2.2095256, 3.3159535, -0.7706928, 0.4416629],
[1.5090443, 3.2764456, -1.1913797, 0.02097607],
[-1.373904, 3.8711405, 0.8524302, 2.064786],
[-1.8139812, 3.538856, -1.0056694, 0.20668638]])
results = dict(
points=points,
pts_file_name=pts_file_name,
sample_idx=sample_idx,
file_name=file_name,
bbox3d_fields=bbox3d_fields)
results = multi_scale_flip_aug_3D(results)
expected_points = torch.tensor(
[[-2.2095, 3.3160, -0.7707, 0.4417], [-1.3739, 3.8711, 0.8524, 2.0648],
[-1.8140, 3.5389, -1.0057, 0.2067], [0.2040, 1.4268, -1.0504, 0.1620],
[1.5090, 3.2764, -1.1914, 0.0210]],
dtype=torch.float64)
assert torch.allclose(
results['points'][0]._data, expected_points, atol=1e-4)
@@ -3,7 +3,7 @@ import numpy as np
 import torch

 from mmdet3d.core import Box3DMode, CameraInstance3DBoxes, LiDARInstance3DBoxes
-from mmdet3d.datasets import ObjectNoise, ObjectSample
+from mmdet3d.datasets import ObjectNoise, ObjectSample, RandomFlip3D


 def test_remove_points_in_boxes():
@@ -35,7 +35,6 @@ def test_remove_points_in_boxes():


 def test_object_sample():
-    import pickle
     db_sampler = mmcv.ConfigDict({
         'data_root': './tests/data/kitti/',
         'info_path': './tests/data/kitti/kitti_dbinfos_train.pkl',
@@ -51,8 +50,6 @@ def test_object_sample():
             'Pedestrian': 6
         }
     })
-    with open('./tests/data/kitti/kitti_dbinfos_train.pkl', 'rb') as f:
-        db_infos = pickle.load(f)
     np.random.seed(0)
     object_sample = ObjectSample(db_sampler)
     points = np.fromfile(
@@ -60,11 +57,19 @@
         np.float32).reshape(-1, 4)
     annos = mmcv.load('./tests/data/kitti/kitti_infos_train.pkl')
     info = annos[0]
+    rect = info['calib']['R0_rect'].astype(np.float32)
+    Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)
     annos = info['annos']
+    loc = annos['location']
+    dims = annos['dimensions']
+    rots = annos['rotation_y']
     gt_names = annos['name']
-    gt_bboxes_3d = db_infos['Pedestrian'][0]['box3d_lidar']
-    gt_bboxes_3d = LiDARInstance3DBoxes([gt_bboxes_3d])
-    CLASSES = ('Car', 'Pedestrian', 'Cyclist')
+    gt_bboxes_3d = np.concatenate([loc, dims, rots[..., np.newaxis]],
+                                  axis=1).astype(np.float32)
+    gt_bboxes_3d = CameraInstance3DBoxes(gt_bboxes_3d).convert_to(
+        Box3DMode.LIDAR, np.linalg.inv(rect @ Trv2c))
+    CLASSES = ('Pedestrian', 'Cyclist', 'Car')
     gt_labels = []
     for cat in gt_names:
         if cat in CLASSES:
@@ -87,9 +92,9 @@
         'classes=[\'Pedestrian\', \'Cyclist\', \'Car\'], ' \
         'sample_groups={\'Pedestrian\': 6}'
     assert repr_str == expected_repr_str
-    assert points.shape == (1177, 4)
-    assert gt_bboxes_3d.tensor.shape == (2, 7)
-    assert np.all(gt_labels_3d == [1, 0])
+    assert points.shape == (800, 4)
+    assert gt_bboxes_3d.tensor.shape == (1, 7)
+    assert np.all(gt_labels_3d == [0])


 def test_object_noise():
@@ -125,3 +130,59 @@ def test_object_noise():
assert repr_str == expected_repr_str
assert points.shape == (800, 4)
assert torch.allclose(gt_bboxes_3d, expected_gt_bboxes_3d, 1e-3)

def test_random_flip_3d():
random_flip_3d = RandomFlip3D(
flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0)
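    # Both flip ratios are 1.0, so the flips are applied deterministically
    # and the expected points/boxes below can be asserted exactly.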
points = np.array([[22.7035, 9.3901, -0.2848, 0.0000],
[21.9826, 9.1766, -0.2698, 0.0000],
[21.4329, 9.0209, -0.2578, 0.0000],
[21.3068, 9.0205, -0.2558, 0.0000],
[21.3400, 9.1305, -0.2578, 0.0000],
[21.3291, 9.2099, -0.2588, 0.0000],
[21.2759, 9.2599, -0.2578, 0.0000],
[21.2686, 9.2982, -0.2588, 0.0000],
[21.2334, 9.3607, -0.2588, 0.0000],
[21.2179, 9.4372, -0.2598, 0.0000]])
bbox3d_fields = ['gt_bboxes_3d']
img_fields = []
box_type_3d = LiDARInstance3DBoxes
gt_bboxes_3d = LiDARInstance3DBoxes(
torch.tensor(
[[38.9229, 18.4417, -1.1459, 0.7100, 1.7600, 1.8600, -2.2652],
[12.7768, 0.5795, -2.2682, 0.5700, 0.9900, 1.7200, -2.5029],
[12.7557, 2.2996, -1.4869, 0.6100, 1.1100, 1.9000, -1.9390],
[10.6677, 0.8064, -1.5435, 0.7900, 0.9600, 1.7900, 1.0856],
[5.0903, 5.1004, -1.2694, 0.7100, 1.7000, 1.8300, -1.9136]]))
input_dict = dict(
points=points,
bbox3d_fields=bbox3d_fields,
box_type_3d=box_type_3d,
img_fields=img_fields,
gt_bboxes_3d=gt_bboxes_3d)
input_dict = random_flip_3d(input_dict)
points = input_dict['points']
gt_bboxes_3d = input_dict['gt_bboxes_3d'].tensor
expected_points = np.array([[22.7035, -9.3901, -0.2848, 0.0000],
[21.9826, -9.1766, -0.2698, 0.0000],
[21.4329, -9.0209, -0.2578, 0.0000],
[21.3068, -9.0205, -0.2558, 0.0000],
[21.3400, -9.1305, -0.2578, 0.0000],
[21.3291, -9.2099, -0.2588, 0.0000],
[21.2759, -9.2599, -0.2578, 0.0000],
[21.2686, -9.2982, -0.2588, 0.0000],
[21.2334, -9.3607, -0.2588, 0.0000],
[21.2179, -9.4372, -0.2598, 0.0000]])
expected_gt_bboxes_3d = torch.tensor(
[[38.9229, -18.4417, -1.1459, 0.7100, 1.7600, 1.8600, 5.4068],
[12.7768, -0.5795, -2.2682, 0.5700, 0.9900, 1.7200, 5.6445],
[12.7557, -2.2996, -1.4869, 0.6100, 1.1100, 1.9000, 5.0806],
[10.6677, -0.8064, -1.5435, 0.7900, 0.9600, 1.7900, 2.0560],
[5.0903, -5.1004, -1.2694, 0.7100, 1.7000, 1.8300, 5.0552]])
repr_str = repr(random_flip_3d)
expected_repr_str = 'RandomFlip3D(sync_2d=True,' \
'flip_ratio_bev_vertical=1.0)'
assert np.allclose(points, expected_points)
assert torch.allclose(gt_bboxes_3d, expected_gt_bboxes_3d)
assert repr_str == expected_repr_str