Commit 4f4c4d44 authored by StarGazer1995

setup new backbones and config file

parent 5f7b31cc

model = dict(
    type='VoteNet',
    backbone=dict(
        type='PointNet2SASSGWithAtt',
        in_channels=4,
        num_points=(2048, 1024, 512, 256),
        radius=(0.2, 0.4, 0.8, 1.2),
        num_samples=(64, 32, 16, 16),
        sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
                     (128, 128, 256)),
        fp_channels=((256, 256), (256, 256)),
        norm_cfg=dict(type='BN2d'),
        pool_mod='max'),
    bbox_head=dict(
        type='VoteHead',
        vote_module_cfg=dict(
            in_channels=256,
            vote_per_seed=1,
            gt_per_seed=3,
            conv_channels=(256, 256),
            conv_cfg=dict(type='Conv1d'),
            norm_cfg=dict(type='BN1d'),
            norm_feats=True,
            vote_loss=dict(
                type='ChamferDistance',
                mode='l1',
                reduction='none',
                loss_dst_weight=10.0)),
        vote_aggregation_cfg=dict(
            num_point=256,
            radius=0.3,
            num_sample=16,
            mlp_channels=[256, 128, 128, 128],
            use_xyz=True,
            normalize_xyz=True),
        feat_channels=(128, 128),
        conv_cfg=dict(type='Conv1d'),
        norm_cfg=dict(type='BN1d'),
        objectness_loss=dict(
            type='CrossEntropyLoss',
            class_weight=[0.2, 0.8],
            reduction='sum',
            loss_weight=5.0),
        center_loss=dict(
            type='ChamferDistance',
            mode='l2',
            reduction='sum',
            loss_src_weight=10.0,
            loss_dst_weight=10.0),
        dir_class_loss=dict(
            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
        dir_res_loss=dict(
            type='SmoothL1Loss', reduction='sum', loss_weight=10.0),
        size_class_loss=dict(
            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0),
        size_res_loss=dict(
            type='SmoothL1Loss', reduction='sum', loss_weight=10.0 / 3.0),
        semantic_loss=dict(
            type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)))
# model training and testing settings
train_cfg = dict(pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote')
test_cfg = dict(
    sample_mod='seed', nms_thr=0.25, score_thr=0.05, per_class_proposal=True)
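
A minimal sketch of how this config could be consumed, assuming it is saved to disk (the filename votenet_att_config.py is illustrative; the commit does not show the config's path) and that the build-from-config API of this era of mmdetection3d is available:

from mmcv import Config
from mmdet3d.models import build_detector

# Hypothetical filename for the config above.
cfg = Config.fromfile('votenet_att_config.py')
# train_cfg/test_cfg are top-level in this config, matching the older
# mmdet3d build_detector signature.
model = build_detector(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
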
import torch
from mmcv.runner import load_checkpoint
from torch import nn as nn
from torch.nn import functional as F

from mmdet3d.ops import PointFPModule, PointSAModule
from mmdet.models import BACKBONES


@BACKBONES.register_module()
class PointNet2SASSGWithAtt(nn.Module):
    """PointNet++ with single-scale grouping and attention-weighted
    feature propagation.

    Args:
        in_channels (int): Input channels of point cloud.
        num_points (tuple[int]): The number of points which each SA
            module samples.
        radius (tuple[float]): Sampling radii of each SA module.
        num_samples (tuple[int]): The number of samples for ball
            query in each SA module.
        sa_channels (tuple[tuple[int]]): Out channels of each mlp in
            each SA module.
        fp_channels (tuple[tuple[int]]): Out channels of each mlp in
            each FP module.
        norm_cfg (dict): Config of normalization layer.
        pool_mod (str): Pool method ('max' or 'avg') for SA modules.
        use_xyz (bool): Whether to use xyz as a part of features.
        normalize_xyz (bool): Whether to normalize xyz with radii in
            each SA module.
    """
    def __init__(self,
                 in_channels,
                 num_points=(2048, 1024, 512, 256),
                 radius=(0.2, 0.4, 0.8, 1.2),
                 num_samples=(64, 32, 16, 16),
                 sa_channels=((64, 64, 128), (128, 128, 256),
                              (128, 128, 256), (128, 128, 256)),
                 fp_channels=((256, 256), (256, 256)),
                 norm_cfg=dict(type='BN2d'),
                 pool_mod='max',
                 use_xyz=True,
                 normalize_xyz=True):
        super().__init__()
        self.num_sa = len(sa_channels)
        self.num_fp = len(fp_channels)

        assert len(num_points) == len(radius) == len(num_samples) == len(
            sa_channels)
        assert len(sa_channels) >= len(fp_channels)
        assert pool_mod in ['max', 'avg']

        self.SA_modules = nn.ModuleList()
        sa_in_channel = in_channels - 3  # number of channels without xyz
        skip_channel_list = [sa_in_channel]

        for sa_index in range(self.num_sa):
            cur_sa_mlps = list(sa_channels[sa_index])
            cur_sa_mlps = [sa_in_channel] + cur_sa_mlps
            sa_out_channel = cur_sa_mlps[-1]

            self.SA_modules.append(
                PointSAModule(
                    num_point=num_points[sa_index],
                    radius=radius[sa_index],
                    num_sample=num_samples[sa_index],
                    mlp_channels=cur_sa_mlps,
                    norm_cfg=norm_cfg,
                    use_xyz=use_xyz,
                    pool_mod=pool_mod,
                    normalize_xyz=normalize_xyz))
            skip_channel_list.append(sa_out_channel)
            sa_in_channel = sa_out_channel

        self.FP_modules = nn.ModuleList()

        fp_source_channel = skip_channel_list.pop()
        fp_target_channel = skip_channel_list.pop()
        for fp_index in range(len(fp_channels)):
            cur_fp_mlps = list(fp_channels[fp_index])
            cur_fp_mlps = [fp_source_channel + fp_target_channel] + cur_fp_mlps
            self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))
            if fp_index != len(fp_channels) - 1:
                fp_source_channel = cur_fp_mlps[-1]
                fp_target_channel = skip_channel_list.pop()
    def init_weights(self, pretrained=None):
        """Initialize the weights of the PointNet++ backbone."""
        # Do not initialize the conv layers
        # to follow the original implementation
        if isinstance(pretrained, str):
            from mmdet3d.utils import get_root_logger
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
    @staticmethod
    def _split_point_feats(points):
        """Split coordinates and features of input points.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            torch.Tensor: Coordinates of input points.
            torch.Tensor: Features of input points.
        """
        xyz = points[..., 0:3].contiguous()
        if points.size(-1) > 3:
            features = points[..., 3:].transpose(1, 2).contiguous()
        else:
            features = None

        return xyz, features
    def forward(self, points):
        """Forward pass.

        Args:
            points (torch.Tensor): Point coordinates with features,
                with shape (B, N, 3 + input_feature_dim).

        Returns:
            dict[str, list[torch.Tensor]]: Outputs after SA and FP modules.

                - fp_xyz (list[torch.Tensor]): The coordinates of
                  each fp features.
                - fp_features (list[torch.Tensor]): The features
                  from each Feature Propagate Layer with attention.
                - fp_indices (list[torch.Tensor]): Indices of the
                  input points.
        """
        xyz, features = self._split_point_feats(points)

        batch, num_points = xyz.shape[:2]
        indices = xyz.new_tensor(range(num_points)).unsqueeze(0).repeat(
            batch, 1).long()

        sa_xyz = [xyz]
        sa_features = [features]
        sa_indices = [indices]

        for i in range(self.num_sa):
            cur_xyz, cur_features, cur_indices = self.SA_modules[i](
                sa_xyz[i], sa_features[i])
            sa_xyz.append(cur_xyz)
            sa_features.append(cur_features)
            sa_indices.append(
                torch.gather(sa_indices[-1], 1, cur_indices.long()))

        fp_xyz = [sa_xyz[-1]]
        fp_features = [sa_features[-1]]
        fp_indices = [sa_indices[-1]]

        for i in range(self.num_fp):
            fp_features.append(self.FP_modules[i](
                sa_xyz[self.num_sa - i - 1], sa_xyz[self.num_sa - i],
                sa_features[self.num_sa - i - 1], fp_features[-1]))
            fp_xyz.append(sa_xyz[self.num_sa - i - 1])
            fp_indices.append(sa_indices[self.num_sa - i - 1])

        # Attention re-weighting of the propagated features. The original
        # line `x = nn.softmax(fp_features)` does not run as written
        # (`nn.softmax` does not exist and `fp_features` is a list); the
        # presumed intent is a channel-wise softmax gate on the final FP
        # features.
        att = F.softmax(fp_features[-1], dim=1)
        fp_features[-1] = att * fp_features[-1]

        ret = dict(
            fp_xyz=fp_xyz, fp_features=fp_features, fp_indices=fp_indices)
        return ret
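
A quick smoke test of the new backbone, as a sketch only: the sampling and grouping ops in mmdet3d are CUDA-only, so this assumes a GPU build; the batch and point counts are illustrative.

if __name__ == '__main__':
    if torch.cuda.is_available():
        backbone = PointNet2SASSGWithAtt(in_channels=4).cuda()
        # (B, N, 3 coords + 1 extra feature), matching in_channels=4.
        points = torch.rand(2, 16384, 4).cuda()
        out = backbone(points)
        # One entry from the last SA stage plus one per FP stage (two here).
        assert len(out['fp_features']) == 3
        # Final FP output: 256 channels over the 1024 points of SA stage 2.
        print(out['fp_features'][-1].shape)  # torch.Size([2, 256, 1024])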