Commit 31b3505d authored by wuyuefeng, committed by zhangwenwei

Core unittest

parent ce4f66b6
@@ -129,6 +129,76 @@ We compare the training speed (samples/s) with other codebases if they implement
python train.py --dataset sunrgbd --batch_size 16
```
Then benchmark the test speed by running
```bash
python eval.py --dataset sunrgbd --checkpoint_path log_sunrgbd/checkpoint.tar --batch_size 1 --dump_dir eval_sunrgbd --cluster_sampling seed_fps --use_3d_nms --use_cls_nms --per_class_proposal
```
Note that eval.py is modified to compute inference time.
<details>
<summary>
(diff to benchmark similar models - click to expand)
</summary>
```diff
diff --git a/eval.py b/eval.py
index c0b2886..04921e9 100644
--- a/eval.py
+++ b/eval.py
@@ -10,6 +10,7 @@ import os
import sys
import numpy as np
from datetime import datetime
+import time
import argparse
import importlib
import torch
@@ -28,7 +29,7 @@ parser.add_argument('--checkpoint_path', default=None, help='Model checkpoint pa
parser.add_argument('--dump_dir', default=None, help='Dump dir to save sample outputs [default: None]')
parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]')
parser.add_argument('--num_target', type=int, default=256, help='Point Number [default: 256]')
-parser.add_argument('--batch_size', type=int, default=8, help='Batch Size during training [default: 8]')
+parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during training [default: 8]')
parser.add_argument('--vote_factor', type=int, default=1, help='Number of votes generated from each seed [default: 1]')
parser.add_argument('--cluster_sampling', default='vote_fps', help='Sampling strategy for vote clusters: vote_fps, seed_fps, random [default: vote_fps]')
parser.add_argument('--ap_iou_thresholds', default='0.25,0.5', help='A list of AP IoU thresholds [default: 0.25,0.5]')
@@ -132,6 +133,7 @@ CONFIG_DICT = {'remove_empty_box': (not FLAGS.faster_eval), 'use_3d_nms': FLAGS.
 # ------------------------------------------------------------------------- GLOBAL CONFIG END
 def evaluate_one_epoch():
+    time_list = list()
     stat_dict = {}
     ap_calculator_list = [APCalculator(iou_thresh, DATASET_CONFIG.class2type) \
         for iou_thresh in AP_IOU_THRESHOLDS]
@@ -144,6 +146,8 @@ def evaluate_one_epoch():
         # Forward pass
         inputs = {'point_clouds': batch_data_label['point_clouds']}
+        torch.cuda.synchronize()
+        start_time = time.perf_counter()
         with torch.no_grad():
             end_points = net(inputs)
@@ -161,6 +165,12 @@ def evaluate_one_epoch():
         batch_pred_map_cls = parse_predictions(end_points, CONFIG_DICT)
         batch_gt_map_cls = parse_groundtruths(end_points, CONFIG_DICT)
+        torch.cuda.synchronize()
+        elapsed = time.perf_counter() - start_time
+        time_list.append(elapsed)
+
+        if len(time_list) == 200:
+            print("average inference time: %.4f" % (sum(time_list[5:]) / len(time_list[5:])))
     for ap_calculator in ap_calculator_list:
         ap_calculator.step(batch_pred_map_cls, batch_gt_map_cls)
```
</details>
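The essence of the diff is the synchronize-then-time pattern around the measured span (the diff also includes post-processing before the second synchronize). A minimal sketch of that pattern, with illustrative names:
```python
import time

import torch


def timed_forward(net, inputs, time_list):
    # Synchronize so queued GPU kernels do not leak into the measurement.
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        end_points = net(inputs)
    torch.cuda.synchronize()
    # Record the elapsed wall-clock time for this batch.
    time_list.append(time.perf_counter() - start)
    return end_points
```
Note that the diff averages `time_list[5:]`, i.e. it drops the first few iterations to exclude warm-up overhead.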
### PointPillars-car
* __MMDetection3D__: With release v0.1.0, run
......
@@ -6,6 +6,18 @@ import numpy as np
def camera_to_lidar(points, r_rect, velo2cam):
"""Convert points in camera coordinate to lidar coordinate.
Args:
points (np.ndarray, shape=[N, 3]): Points in camera coordinate.
r_rect (np.ndarray, shape=[4, 4]): Matrix to project points in
specific camera coordinate (e.g. CAM2) to CAM0.
velo2cam (np.ndarray, shape=[4, 4]): Matrix to project points in
camera coordinate to lidar coordinate.
Returns:
np.ndarray, shape=[N, 3]: Points in lidar coordinate.
"""
points_shape = list(points.shape[0:-1])
if points.shape[-1] == 3:
points = np.concatenate([points, np.ones(points_shape + [1])], axis=-1)
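The remainder of the conversion is one homogeneous transform. A minimal sketch, assuming the usual KITTI chain camera = r_rect @ velo2cam @ lidar (`camera_to_lidar_sketch` is an illustrative name, not the implementation):
```python
import numpy as np


def camera_to_lidar_sketch(points, r_rect, velo2cam):
    # Pad [N, 3] points to homogeneous [N, 4] coordinates.
    points = np.concatenate(
        [points, np.ones((points.shape[0], 1))], axis=-1)
    # Invert the camera chain: lidar = inv(r_rect @ velo2cam) applied to
    # camera points (row-vector form, hence the transpose).
    lidar_points = points @ np.linalg.inv((r_rect @ velo2cam).T)
    return lidar_points[..., :3]
```
The test_camera_to_lidar unit test at the end of this commit exercises this path.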
@@ -14,6 +26,18 @@ def camera_to_lidar(points, r_rect, velo2cam):
def box_camera_to_lidar(data, r_rect, velo2cam):
"""Covert boxes in camera coordinate to lidar coordinate.
Args:
data (np.ndarray, shape=[N, 7]): Boxes in camera coordinate.
r_rect (np.ndarray, shape=[4, 4]): Matrix to project points in
specific camera coordinate (e.g. CAM2) to CAM0.
velo2cam (np.ndarray, shape=[4, 4]): Matrix to project points in
camera coordinate to lidar coordinate.
Returns:
np.ndarray, shape=[N, 7]: Boxes in lidar coordinate.
"""
xyz = data[:, 0:3]
l, h, w = data[:, 3:4], data[:, 4:5], data[:, 5:6]
r = data[:, 6:7]
@@ -96,6 +120,16 @@ def center_to_corner_box2d(centers, dims, angles=None, origin=0.5):
@numba.jit(nopython=True)
def depth_to_points(depth, trunc_pixel):
"""Convert depth map to points.
Args:
depth (np.ndarray, shape=[H, W]): Depth map in which
the rows [0~`trunc_pixel`] are truncated.
trunc_pixel (int): The number of truncated rows.
Returns:
np.ndarray: Points in camera coordinates.
"""
num_pts = np.sum(depth[trunc_pixel:, ] > 0.1)
points = np.zeros((num_pts, 3), dtype=depth.dtype)
x = np.array([0, 0, 1], dtype=depth.dtype)
@@ -110,6 +144,21 @@ def depth_to_points(depth, trunc_pixel):
def depth_to_lidar_points(depth, trunc_pixel, P2, r_rect, velo2cam):
"""Convert depth map to points in lidar coordinate.
Args:
depth (np.ndarray, shape=[H, W]): Depth map in which
the rows [0~`trunc_pixel`] are truncated.
trunc_pixel (int): The number of truncated rows.
P2 (np.ndarray, shape=[4, 4]): Intrinsics of Camera2.
r_rect (np.ndarray, shape=[4, 4]): Matrix to project points in
specific camera coordinate (e.g. CAM2) to CAM0.
velo2cam (np.ndarray, shape=[4, 4]): Matrix to project points in
camera coordinate to lidar coordinate.
Returns:
np.ndarray: Points in lidar coordinates.
"""
pts = depth_to_points(depth, trunc_pixel)
points_shape = list(pts.shape[0:-1])
points = np.concatenate([pts, np.ones(points_shape + [1])], axis=-1)
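The body continues by undoing the intrinsics and reusing camera_to_lidar. A hedged sketch of the full chain, assuming depth_to_points and camera_to_lidar from this file:
```python
import numpy as np


def depth_to_lidar_points_sketch(depth, trunc_pixel, P2, r_rect, velo2cam):
    pts = depth_to_points(depth, trunc_pixel)  # [N, 3] in camera frame
    pts = np.concatenate([pts, np.ones((pts.shape[0], 1))], axis=-1)
    pts = pts @ np.linalg.inv(P2.T)  # undo the camera intrinsics
    # Reuse the frame change documented earlier in this file.
    return camera_to_lidar(pts[..., :3], r_rect, velo2cam)
```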
@@ -119,6 +168,16 @@ def depth_to_lidar_points(depth, trunc_pixel, P2, r_rect, velo2cam):
def rotation_3d_in_axis(points, angles, axis=0):
"""Rotate points in specific axis.
Args:
points (np.ndarray, shape=[N, point_size, 3]): Points to rotate.
angles (np.ndarray, shape=[N]): Rotation angles.
axis (int): Axis to rotate along.
Returns:
np.ndarray: Rotated points.
"""
# points: [N, point_size, 3]
rot_sin = np.sin(angles)
rot_cos = np.cos(angles)
@@ -170,6 +229,14 @@ def center_to_corner_box3d(centers,
@numba.jit(nopython=True)
def box2d_to_corner_jit(boxes):
"""Convert box2d to corner.
Args:
boxes (np.ndarray, shape=[N, 5]): Boxes2d with rotation.
Returns:
box_corners (np.ndarray, shape=[N, 4, 2]): Box corners.
"""
num_box = boxes.shape[0]
corners_norm = np.zeros((4, 2), dtype=boxes.dtype)
corners_norm[1, 1] = 1.0
@@ -193,6 +260,14 @@ def box2d_to_corner_jit(boxes):
@numba.njit
def corner_to_standup_nd_jit(boxes_corner):
"""Convert boxes_corner to aligned (min-max) boxes.
Args:
boxes_corner (np.ndarray, shape=[N, 2**dim, dim]): Boxes corners.
Returns:
np.ndarray, shape=[N, dim*2]: Aligned (min-max) boxes.
"""
num_boxes = boxes_corner.shape[0]
ndim = boxes_corner.shape[-1]
result = np.zeros((num_boxes, ndim * 2), dtype=boxes_corner.dtype)
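Outside numba, the documented conversion is just a per-box min and max over the corner axis. A plain-NumPy sketch (`corner_to_standup_nd_sketch` is an illustrative name):
```python
import numpy as np


def corner_to_standup_nd_sketch(boxes_corner):
    # [N, 2**dim, dim] -> [N, dim*2]: per-box minima then maxima.
    return np.concatenate(
        [boxes_corner.min(axis=1), boxes_corner.max(axis=1)], axis=-1)
```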
@@ -229,6 +304,16 @@ def corner_to_surfaces_3d_jit(corners):
def rotation_points_single_angle(points, angle, axis=0):
"""Rotate points with a single angle.
Args:
points (np.ndarray, shape=[N, 3]): Points to rotate.
angle (np.ndarray, shape=[1]): Rotation angle.
axis (int): Axis to rotate along.
Returns:
np.ndarray: Rotated points.
"""
# points: [N, 3]
rot_sin = np.sin(angle)
rot_cos = np.cos(angle)
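For intuition, the axis=2 (z) branch reduces to a single matrix product. A sketch under the row-vector convention used in this file (`rotate_points_z_sketch` is an illustrative name, not a function in this module):
```python
import numpy as np


def rotate_points_z_sketch(points, angle):
    rot_sin = np.sin(angle)
    rot_cos = np.cos(angle)
    # Transposed rotation matrix, since points are row vectors.
    rot_mat_T = np.array([[rot_cos, rot_sin, 0.0],
                          [-rot_sin, rot_cos, 0.0],
                          [0.0, 0.0, 1.0]])
    return points @ rot_mat_T
```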
@@ -251,6 +336,15 @@ def rotation_points_single_angle(points, angle, axis=0):
def points_cam2img(points_3d, proj_mat):
"""Project points in camera coordinates to image coordinates.
Args:
points_3d (np.ndarray): Points in shape (N, 3).
proj_mat (np.ndarray): Transformation matrix between coordinates.
Returns:
np.ndarray: Points in image coordinates with shape [N, 2].
"""
points_shape = list(points_3d.shape)
points_shape[-1] = 1
points_4 = np.concatenate([points_3d, np.zeros(points_shape)], axis=-1)
@@ -259,7 +353,16 @@ def points_cam2img(points_3d, proj_mat):
return point_2d_res
def box3d_to_bbox(box3d, P2):
"""Convert box3d in camera coordinates to bbox in image coordinates.
Args:
box3d (np.ndarray, shape=[N, 7]): Boxes in camera coordinate.
P2 (np.array, shape=[4, 4]): Intrinsics of Camera2.
Returns:
np.ndarray, shape=[N, 4]: Boxes 2d in image coordinates.
"""
box_corners = center_to_corner_box3d(
box3d[:, :3], box3d[:, 3:6], box3d[:, 6], [0.5, 1.0, 0.5], axis=1)
box_corners_in_image = points_cam2img(box_corners, P2)
@@ -293,6 +396,17 @@ def corner_to_surfaces_3d(corners):
def points_in_rbbox(points, rbbox, z_axis=2, origin=(0.5, 0.5, 0)):
"""Check points in rotated bbox and return indicces.
Args:
points (np.ndarray, shape=[N, 3+dim]): Points to query.
rbbox (np.ndarray, shape=[M, 7]): Boxes3d with rotation.
z_axis (int): Indicate which axis is height.
origin (tuple[int]): Indicate the position of box center.
Returns:
np.ndarray, shape=[N, M]: Indices of points in each box.
"""
# TODO: this function is different from PointCloud3D, be careful
# when start to use nuscene, check the input
rbbox_corners = center_to_corner_box3d(
@@ -303,6 +417,14 @@ def points_in_rbbox(points, rbbox, z_axis=2, origin=(0.5, 0.5, 0)):
def minmax_to_corner_2d(minmax_box):
"""Convert minmax box to corners2d.
Args:
minmax_box (np.ndarray, shape=[N, dims]): minmax boxes.
Returns:
np.ndarray: 2d corners of boxes.
"""
ndim = minmax_box.shape[-1] // 2
center = minmax_box[..., :ndim]
dims = minmax_box[..., ndim:] - center
@@ -310,6 +432,18 @@ def minmax_to_corner_2d(minmax_box):
def limit_period(val, offset=0.5, period=np.pi):
"""Limit the value into a period for periodic function.
Args:
val (np.ndarray): The value to be converted.
offset (float, optional): Offset to set the value range.
Defaults to 0.5.
period (float, optional): Period of the value. Defaults to np.pi.
Returns:
np.ndarray: Value in the range of
[-offset * period, (1-offset) * period].
"""
return val - np.floor(val / period + offset) * period
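A quick worked check of the documented range, using the defaults offset=0.5 and period=np.pi:
```python
import numpy as np

val = np.array([0.3, 2.0, -3.5])
out = val - np.floor(val / np.pi + 0.5) * np.pi
# 0.3 stays put; 2.0 maps to 2.0 - pi ~= -1.14; -3.5 maps to -3.5 + pi ~= -0.36.
assert ((out >= -0.5 * np.pi) & (out < 0.5 * np.pi)).all()
```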
@@ -318,7 +452,8 @@ def create_anchors_3d_range(feature_size,
sizes=((1.6, 3.9, 1.56), ),
rotations=(0, np.pi / 2),
dtype=np.float32):
"""
"""Create anchors 3d by range.
Args:
feature_size (list[float] | tuple[float]): Feature map size. It is
either a list of a tuple of [D, H, W](in order of z, y, and x).
@@ -360,13 +495,20 @@ def create_anchors_3d_range(feature_size,
return np.transpose(ret, [2, 1, 0, 3, 4, 5])
def center_to_minmax_2d_0_5(centers, dims):
return np.concatenate([centers - dims / 2, centers + dims / 2], axis=-1)
def center_to_minmax_2d(centers, dims, origin=0.5):
"""Center to minmax.
Args:
centers (np.ndarray): Center points.
dims (np.ndarray): Dimensions.
origin (list or array or float): Origin point relative to the
smallest point.
Returns:
np.ndarray: Minmax points.
"""
if origin == 0.5:
return center_to_minmax_2d_0_5(centers, dims)
corners = center_to_corner_box2d(centers, dims, origin=origin)
return corners[:, [0, 2]].reshape([-1, 4])
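For the origin == 0.5 branch above, the result is simply [center - dims/2, center + dims/2]. A small worked example:
```python
import numpy as np

centers = np.array([[2.0, 3.0]])
dims = np.array([[4.0, 2.0]])
# [x - w/2, y - h/2, x + w/2, y + h/2] -> [[0., 2., 4., 4.]]
minmax = np.concatenate([centers - dims / 2, centers + dims / 2], axis=-1)
```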
@@ -429,16 +571,20 @@ def iou_jit(boxes, query_boxes, mode='iou', eps=0.0):
return overlaps
def change_box3d_center_(box3d, src, dst):
dst = np.array(dst, dtype=box3d.dtype)
src = np.array(src, dtype=box3d.dtype)
box3d[..., :3] += box3d[..., 3:6] * (dst - src)
def projection_matrix_to_CRT_kitti(proj):
"""Split projection matrix of kitti.
P = C @ [R|T]
C is an upper triangular matrix, so we invert CR and use a QR
decomposition, which is stable for all KITTI camera projection matrices.
Args:
proj (np.ndarray, shape=[4, 4]): Intrinsics of camera.
Returns:
tuple[np.ndarray]: Split matrices of C, R and T.
"""
CR = proj[0:3, 0:3]
CT = proj[0:3, 3]
RinvCinv = np.linalg.inv(CR)
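Putting the docstring's idea together: inv(C @ R) = R^-1 @ C^-1, and np.linalg.qr returns exactly an orthogonal factor (R^-1) and an upper-triangular factor (C^-1). A hedged sketch of the whole split (`split_projection_sketch` is an illustrative name):
```python
import numpy as np


def split_projection_sketch(proj):
    CR = proj[0:3, 0:3]
    CT = proj[0:3, 3]
    # inv(C @ R) = R^-1 @ C^-1; QR recovers both factors stably.
    Rinv, Cinv = np.linalg.qr(np.linalg.inv(CR))
    C = np.linalg.inv(Cinv)
    R = np.linalg.inv(Rinv)
    T = Cinv @ CT  # P[:3, 3] = C @ T, so T = C^-1 @ CT
    return C, R, T
```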
@@ -450,6 +596,20 @@ def projection_matrix_to_CRT_kitti(proj):
def remove_outside_points(points, rect, Trv2c, P2, image_shape):
"""Remove points which are outside of image.
Args:
points (np.ndarray, shape=[N, 3+dims]): Total points.
rect (np.ndarray, shape=[4, 4]): Matrix to project points in
specific camera coordinate (e.g. CAM2) to CAM0.
Trv2c (np.ndarray, shape=[4, 4]): Matrix to project points in
camera coordinate to lidar coordinate.
P2 (np.ndarray, shape=[4, 4]): Intrinsics of Camera2.
image_shape (list[int]): Shape of image.
Returns:
np.ndarray, shape=[N, 3+dims]: Filtered points.
"""
# 5x faster than remove_outside_points_v1(2ms vs 10ms)
C, R, T = projection_matrix_to_CRT_kitti(P2)
image_bbox = [0, 0, image_shape[1], image_shape[0]]
@@ -464,6 +624,17 @@ def remove_outside_points(points, rect, Trv2c, P2, image_shape):
def get_frustum(bbox_image, C, near_clip=0.001, far_clip=100):
"""Get frustum corners in camera coordinates.
Args:
bbox_image (list[int]): Box in image coordinates.
C (np.ndarray): Intrinsics.
near_clip (float): Nearest distance of frustum.
far_clip (float): Farthest distance of frustum.
Returns:
np.ndarray, shape=[8, 3]: Coordinates of frustum corners.
"""
fku = C[0, 0]
fkv = -C[1, 1]
u0v0 = C[0:2, 2]
@@ -484,6 +655,17 @@ def get_frustum(bbox_image, C, near_clip=0.001, far_clip=100):
def surface_equ_3d(polygon_surfaces):
"""
Args:
polygon_surfaces (np.ndarray): Polygon surfaces with shape of
[num_polygon, max_num_surfaces, max_num_points_of_surface, 3].
All surface normal vectors must point inwards.
Max_num_points_of_surface must be at least 3.
Returns:
tuple: Normal vectors and offsets d of the surfaces.
"""
# return [a, b, c], d in ax+by+cz+d=0
# polygon_surfaces: [num_polygon, num_surfaces, num_points_of_polygon, 3]
surface_vec = polygon_surfaces[:, :, :2, :] - \
@@ -499,6 +681,21 @@ def surface_equ_3d(polygon_surfaces):
@numba.njit
def _points_in_convex_polygon_3d_jit(points, polygon_surfaces, normal_vec, d,
num_surfaces):
"""
Args:
points (np.ndarray): Input points with shape of (num_points, 3).
polygon_surfaces (np.ndarray): Polygon surfaces with shape of
(num_polygon, max_num_surfaces, max_num_points_of_surface, 3).
All surface normal vectors must point inwards.
Max_num_points_of_surface must be at least 3.
normal_vec (np.ndarray): Normal vectors of polygon_surfaces.
d (np.ndarray): Offsets of polygon_surfaces (d in ax+by+cz+d=0).
num_surfaces (np.ndarray): Number of surfaces each polygon contains,
with shape of (num_polygon,).
Returns:
np.ndarray: Result matrix with the shape of [num_points, num_polygon].
"""
max_num_surfaces, max_num_points_of_surface = polygon_surfaces.shape[1:3]
num_points = points.shape[0]
num_polygons = polygon_surfaces.shape[0]
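The jitted loop reduces to a per-surface sign test. A scalar sketch for one point against one polygon, assuming normal_vec has shape [num_surfaces, 3] and d has shape [num_surfaces]:
```python
import numpy as np


def point_in_polyhedron_sketch(point, normal_vec, d):
    # With inward-pointing normals, a point is inside when it lies on the
    # negative side of every surface plane ax + by + cz + d = 0.
    return bool(np.all(point @ normal_vec.T + d < 0))
```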
......
@@ -640,6 +640,18 @@ def kitti_eval(gt_annos,
dt_annos,
current_classes,
eval_types=['bbox', 'bev', '3d']):
"""KITTI evaluation.
Args:
gt_annos (list[dict]): Contain gt information of each sample.
dt_annos (list[dict]): Contain detected information of each sample.
current_classes (list[str]): Classes to evaluate.
eval_types (list[str], optional): Types to eval.
Defaults to ['bbox', 'bev', '3d'].
Returns:
tuple: String and dict of evaluation results.
"""
assert 'bbox' in eval_types, 'must evaluate bbox at least'
overlap_0_7 = np.array([[0.7, 0.5, 0.5, 0.7,
0.5], [0.7, 0.5, 0.5, 0.7, 0.5],
@@ -749,6 +761,16 @@ def kitti_eval(gt_annos,
def kitti_eval_coco_style(gt_annos, dt_annos, current_classes):
"""coco style evaluation of kitti.
Args:
gt_annos (list[dict]): Contain gt information of each sample.
dt_annos (list[dict]): Contain detected information of each sample.
current_classes (list[str]): Classes to evaluate.
Returns:
str: Evaluation results.
"""
class_to_name = {
0: 'Car',
1: 'Pedestrian',
......
@@ -229,6 +229,15 @@ def rbbox_to_corners(corners, rbbox):
@cuda.jit('(float32[:], float32[:])', device=True, inline=True)
def inter(rbbox1, rbbox2):
"""Compute intersection of two rotated boxes.
Args:
rbbox1 (np.ndarray, shape=[5]): Rotated 2d box.
rbbox2 (np.ndarray, shape=[5]): Rotated 2d box.
Returns:
float: Intersection area of two rotated boxes.
"""
corners1 = cuda.local.array((8, ), dtype=numba.float32)
corners2 = cuda.local.array((8, ), dtype=numba.float32)
intersection_corners = cuda.local.array((16, ), dtype=numba.float32)
@@ -246,6 +255,19 @@ def inter(rbbox1, rbbox2):
@cuda.jit('(float32[:], float32[:], int32)', device=True, inline=True)
def devRotateIoUEval(rbox1, rbox2, criterion=-1):
"""Compute rotated iou on device.
Args:
rbox1 (np.ndarray, shape=[5]): Rotated 2d box.
rbox2 (np.ndarray, shape=[5]): Rotated 2d box.
criterion (int, optional): Indicates the type of IoU.
-1 indicates `area_inter / (area1 + area2 - area_inter)`,
0 indicates `area_inter / area1`,
1 indicates `area_inter / area2`.
Returns:
float: IoU of the two input boxes.
"""
area1 = rbox1[2] * rbox1[3]
area2 = rbox2[2] * rbox2[3]
area_inter = inter(rbox1, rbox2)
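The criterion convention above maps onto the final division as follows; a plain-Python sketch of the tail of the function (`iou_from_areas` is an illustrative name):
```python
def iou_from_areas(area1, area2, area_inter, criterion=-1):
    if criterion == -1:
        return area_inter / (area1 + area2 - area_inter)  # standard IoU
    elif criterion == 0:
        return area_inter / area1
    elif criterion == 1:
        return area_inter / area2
    return area_inter
```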
@@ -268,6 +290,19 @@ def rotate_iou_kernel_eval(N,
dev_query_boxes,
dev_iou,
criterion=-1):
"""Kernel of computing rotated iou.
Args:
N (int): The number of boxes.
K (int): The number of query boxes.
dev_boxes (np.ndarray): Boxes on device.
dev_query_boxes (np.ndarray): Query boxes on device.
dev_iou (np.ndarray): Computed iou to return.
criterion (int, optional): Indicates the type of IoU.
-1 indicates `area_inter / (area1 + area2 - area_inter)`,
0 indicates `area_inter / area1`,
1 indicates `area_inter / area2`.
"""
threadsPerBlock = 8 * 8
row_start = cuda.blockIdx.x
col_start = cuda.blockIdx.y
@@ -310,8 +345,12 @@ def rotate_iou_gpu_eval(boxes, query_boxes, criterion=-1, device_id=0):
Args:
boxes (torch.Tensor): rbboxes. format: centers, dims,
angles(clockwise when positive) with the shape of [N, 5].
query_boxes (float tensor: [K, 5]): rbboxes to compute iou with boxes.
device_id (int, optional): Device to use. Defaults to 0.
criterion (int, optional): Indicates the type of IoU.
-1 indicates `area_inter / (area1 + area2 - area_inter)`,
0 indicates `area_inter / area1`,
1 indicates `area_inter / area2`.
Returns:
np.ndarray: IoU results.
......
import numpy as np
def test_camera_to_lidar():
from mmdet3d.core.bbox.box_np_ops import camera_to_lidar
points = np.array([[1.84, 1.47, 8.41]])
rect = np.array([[0.9999128, 0.01009263, -0.00851193, 0.],
[-0.01012729, 0.9999406, -0.00403767, 0.],
[0.00847068, 0.00412352, 0.9999556, 0.], [0., 0., 0.,
1.]])
Trv2c = np.array([[0.00692796, -0.9999722, -0.00275783, -0.02457729],
[-0.00116298, 0.00274984, -0.9999955, -0.06127237],
[0.9999753, 0.00693114, -0.0011439, -0.3321029],
[0., 0., 0., 1.]])
points_lidar = camera_to_lidar(points, rect, Trv2c)
expected_points = np.array([[8.73138192, -1.85591746, -1.59969933]])
assert np.allclose(points_lidar, expected_points)
def test_box_camera_to_lidar():
from mmdet3d.core.bbox.box_np_ops import box_camera_to_lidar
box = np.array([[1.84, 1.47, 8.41, 1.2, 1.89, 0.48, 0.01]])
rect = np.array([[0.9999128, 0.01009263, -0.00851193, 0.],
[-0.01012729, 0.9999406, -0.00403767, 0.],
[0.00847068, 0.00412352, 0.9999556, 0.], [0., 0., 0.,
1.]])
Trv2c = np.array([[0.00692796, -0.9999722, -0.00275783, -0.02457729],
[-0.00116298, 0.00274984, -0.9999955, -0.06127237],
[0.9999753, 0.00693114, -0.0011439, -0.3321029],
[0., 0., 0., 1.]])
box_lidar = box_camera_to_lidar(box, rect, Trv2c)
expected_box = np.array(
[[8.73138192, -1.85591746, -1.59969933, 0.48, 1.2, 1.89, 0.01]])
assert np.allclose(box_lidar, expected_box)
def test_corners_nd():
from mmdet3d.core.bbox.box_np_ops import corners_nd
dims = np.array([[0.47, 0.98]])
corners = corners_nd(dims)
expected_corners = np.array([[[-0.235, -0.49], [-0.235, 0.49],
[0.235, 0.49], [0.235, -0.49]]])
assert np.allclose(corners, expected_corners)
def test_center_to_corner_box2d():
from mmdet3d.core.bbox.box_np_ops import center_to_corner_box2d
center = np.array([[9.348705, -3.6271024]])
dims = np.array([[0.47, 0.98]])
angles = np.array([-3.14])
corner = center_to_corner_box2d(center, dims, angles)
expected_corner = np.array([[[9.584485, -3.1374772], [9.582925, -4.117476],
[9.112926, -4.1167274],
[9.114486, -3.1367288]]])
assert np.allclose(corner, expected_corner)
def test_rotation_2d():
from mmdet3d.core.bbox.box_np_ops import rotation_2d
angles = np.array([-3.14])
corners = np.array([[[-0.235, -0.49], [-0.235, 0.49], [0.235, 0.49],
[0.235, -0.49]]])
corners_rotated = rotation_2d(corners, angles)
expected_corners = np.array([[[0.2357801, 0.48962511],
[0.2342193, -0.49037365],
[-0.2357801, -0.48962511],
[-0.2342193, 0.49037365]]])
assert np.allclose(corners_rotated, expected_corners)