Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
M
mmdetection3d
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
manli
mmdetection3d
Commits
4f4c4d44
Commit
4f4c4d44
authored
4 years ago
by
StarGazer1995
Browse files
Options
Downloads
Patches
Plain Diff
setup new backbones and config file
parent
5f7b31cc
No related branches found
No related tags found
No related merge requests found
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
configs/_base_/models/votenetwithatt.py
+63
-0
63 additions, 0 deletions
configs/_base_/models/votenetwithatt.py
mmdet3d/models/backbones/pointnet2_sa_ssg_with_attention.py
+165
-0
165 additions, 0 deletions
mmdet3d/models/backbones/pointnet2_sa_ssg_with_attention.py
with
228 additions
and
0 deletions
configs/_base_/models/votenetwithatt.py
0 → 100644
+
63
−
0
View file @
4f4c4d44
model
=
dict
(
type
=
'
VoteNet
'
,
backbone
=
dict
(
type
=
'
PointNet2SASSGWithAtt
'
,
in_channels
=
4
,
num_points
=
(
2048
,
1024
,
512
,
256
),
radius
=
(
0.2
,
0.4
,
0.8
,
1.2
),
num_samples
=
(
64
,
32
,
16
,
16
),
sa_channels
=
((
64
,
64
,
128
),
(
128
,
128
,
256
),
(
128
,
128
,
256
),
(
128
,
128
,
256
)),
fp_channels
=
((
256
,
256
),
(
256
,
256
)),
norm_cfg
=
dict
(
type
=
'
BN2d
'
),
pool_mod
=
'
max
'
),
bbox_head
=
dict
(
type
=
'
VoteHead
'
,
vote_moudule_cfg
=
dict
(
in_channels
=
256
,
vote_per_seed
=
1
,
gt_per_seed
=
3
,
conv_channels
=
(
256
,
256
),
conv_cfg
=
dict
(
type
=
'
Conv1d
'
),
norm_cfg
=
dict
(
type
=
'
BN1d
'
),
norm_feats
=
True
,
vote_loss
=
dict
(
type
=
'
ChamferDistance
'
,
mode
=
'
l1
'
,
reduction
=
'
none
'
,
loss_dst_weight
=
10.0
)),
vote_aggregation_cfg
=
dict
(
num_point
=
256
,
radius
=
0.3
,
num_sample
=
16
,
mlp_channels
=
[
256
,
128
,
128
,
128
],
use_xyz
=
True
,
normalize_xyz
=
True
),
feat_channels
=
(
128
,
128
),
conv_cfg
=
dict
(
type
=
'
Conv1d
'
),
norm_cfg
=
dict
(
type
=
'
BN1d
'
),
objectness_loss
=
dict
(
type
=
'
CrossEntropyLoss
'
,
class_weight
=
[
0.2
,
0.8
],
reduction
=
'
sum
'
,
loss_weight
=
5.0
),
center_loss
=
dict
(
type
=
'
ChamferDistance
'
,
mode
=
'
l2
'
,
reduction
=
'
sum
'
,
loss_src_weight
=
10.0
,
loss_dst_weight
=
10.0
),
dir_class_loss
=
dict
(
type
=
'
CrossEntropyLoss
'
,
reduction
=
'
sum
'
,
loss_weight
=
1.0
),
dir_res_loss
=
dict
(
type
=
'
SmoothL1Loss
'
,
reduction
=
'
sum
'
,
loss_weight
=
10.0
),
size_class_loss
=
dict
(
type
=
'
CrossEntropyLoss
'
,
reduction
=
'
sum
'
,
loss_weight
=
1.0
),
size_res_loss
=
dict
(
type
=
'
SmoothL1Loss
'
,
reduction
=
'
sum
'
,
loss_weight
=
10.0
/
3.0
),
semantic_loss
=
dict
(
type
=
'
CrossEntropyLoss
'
,
reduction
=
'
sum
'
,
loss_weight
=
1.0
)))
# model training and testing settings
train_cfg
=
dict
(
pos_distance_thr
=
0.3
,
neg_distance_thr
=
0.6
,
sample_mod
=
'
vote
'
)
test_cfg
=
dict
(
sample_mod
=
'
seed
'
,
nms_thr
=
0.25
,
score_thr
=
0.05
,
per_class_proposal
=
True
)
This diff is collapsed.
Click to expand it.
mmdet3d/models/backbones/pointnet2_sa_ssg_with_attention.py
0 → 100644
+
165
−
0
View file @
4f4c4d44
import
torch
from
mmcv.runner
import
load_checkpoint
from
torch
import
nn
as
nn
from
mmdet3d.ops
import
PointFPModule
,
PointSAModule
from
mmdet.models
import
BACKBONES
@BACKBONES.register_module
()
class
PointNet2SASSGWithAtt
(
nn
.
Module
):
"""
PointNet2 with Single-scale grouping.
Args:
in_channels (int): Input channels of point cloud.
num_points (tuple[int]): The number of points which each SA
module samples.
radius (tuple[float]): Sampling radii of each SA module.
num_samples (tuple[int]): The number of samples for ball
query in each SA module.
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
fp_channels (tuple[tuple[int]]): Out channels of each mlp in FP module.
norm_cfg (dict): Config of normalization layer.
pool_mod (str): Pool method (
'
max
'
or
'
avg
'
) for SA modules.
use_xyz (bool): Whether to use xyz as a part of features.
normalize_xyz (bool): Whether to normalize xyz with radii in
each SA module.
"""
def
__init__
(
self
,
in_channels
,
num_points
=
(
2048
,
1024
,
512
,
256
),
radius
=
(
0.2
,
0.4
,
0.8
,
1.2
),
num_samples
=
(
64
,
32
,
16
,
16
),
sa_channels
=
((
64
,
64
,
128
),
(
128
,
128
,
256
),
(
128
,
128
,
256
),
(
128
,
128
,
256
)),
fp_channels
=
((
256
,
256
),
(
256
,
256
)),
norm_cfg
=
dict
(
type
=
'
BN2d
'
),
pool_mod
=
'
max
'
,
use_xyz
=
True
,
normalize_xyz
=
True
):
super
().
__init__
()
self
.
num_sa
=
len
(
sa_channels
)
self
.
num_fp
=
len
(
fp_channels
)
assert
len
(
num_points
)
==
len
(
radius
)
==
len
(
num_samples
)
==
len
(
sa_channels
)
assert
len
(
sa_channels
)
>=
len
(
fp_channels
)
assert
pool_mod
in
[
'
max
'
,
'
avg
'
]
self
.
SA_modules
=
nn
.
ModuleList
()
sa_in_channel
=
in_channels
-
3
# number of channels without xyz
skip_channel_list
=
[
sa_in_channel
]
for
sa_index
in
range
(
self
.
num_sa
):
cur_sa_mlps
=
list
(
sa_channels
[
sa_index
])
cur_sa_mlps
=
[
sa_in_channel
]
+
cur_sa_mlps
sa_out_channel
=
cur_sa_mlps
[
-
1
]
self
.
SA_modules
.
append
(
PointSAModule
(
num_point
=
num_points
[
sa_index
],
radius
=
radius
[
sa_index
],
num_sample
=
num_samples
[
sa_index
],
mlp_channels
=
cur_sa_mlps
,
norm_cfg
=
norm_cfg
,
use_xyz
=
use_xyz
,
pool_mod
=
pool_mod
,
normalize_xyz
=
normalize_xyz
))
skip_channel_list
.
append
(
sa_out_channel
)
sa_in_channel
=
sa_out_channel
self
.
FP_modules
=
nn
.
ModuleList
()
fp_source_channel
=
skip_channel_list
.
pop
()
fp_target_channel
=
skip_channel_list
.
pop
()
for
fp_index
in
range
(
len
(
fp_channels
)):
cur_fp_mlps
=
list
(
fp_channels
[
fp_index
])
cur_fp_mlps
=
[
fp_source_channel
+
fp_target_channel
]
+
cur_fp_mlps
self
.
FP_modules
.
append
(
PointFPModule
(
mlp_channels
=
cur_fp_mlps
))
if
fp_index
!=
len
(
fp_channels
)
-
1
:
fp_source_channel
=
cur_fp_mlps
[
-
1
]
fp_target_channel
=
skip_channel_list
.
pop
()
def
init_weights
(
self
,
pretrained
=
None
):
"""
Initialize the weights of PointNet backbone.
"""
# Do not initialize the conv layers
# to follow the original implementation
if
isinstance
(
pretrained
,
str
):
from
mmdet3d.utils
import
get_root_logger
logger
=
get_root_logger
()
load_checkpoint
(
self
,
pretrained
,
strict
=
False
,
logger
=
logger
)
@staticmethod
def
_split_point_feats
(
points
):
"""
Split coordinates and features of input points.
Args:
points (torch.Tensor): Point coordinates with features,
with shape (B, N, 3 + input_feature_dim).
Returns:
torch.Tensor: Coordinates of input points.
torch.Tensor: Features of input points.
"""
xyz
=
points
[...,
0
:
3
].
contiguous
()
if
points
.
size
(
-
1
)
>
3
:
features
=
points
[...,
3
:].
transpose
(
1
,
2
).
contiguous
()
else
:
features
=
None
return
xyz
,
features
def
forward
(
self
,
points
):
"""
Forward pass.
Args:
points (torch.Tensor): point coordinates with features,
with shape (B, N, 3 + input_feature_dim).
Returns:
dict[str, list[torch.Tensor]]: Outputs after SA and FP modules.
- fp_xyz (list[torch.Tensor]): The coordinates of
\
each fp features.
- fp_features (list[torch.Tensor]): The features
\
from each Feature Propagate Layers with attention.
- fp_indices (list[torch.Tensor]): Indices of the
\
input points.
"""
xyz
,
features
=
self
.
_split_point_feats
(
points
)
batch
,
num_points
=
xyz
.
shape
[:
2
]
indices
=
xyz
.
new_tensor
(
range
(
num_points
)).
unsqueeze
(
0
).
repeat
(
batch
,
1
).
long
()
sa_xyz
=
[
xyz
]
sa_features
=
[
features
]
sa_indices
=
[
indices
]
for
i
in
range
(
self
.
num_sa
):
cur_xyz
,
cur_features
,
cur_indices
=
self
.
SA_modules
[
i
](
sa_xyz
[
i
],
sa_features
[
i
])
sa_xyz
.
append
(
cur_xyz
)
sa_features
.
append
(
cur_features
)
sa_indices
.
append
(
torch
.
gather
(
sa_indices
[
-
1
],
1
,
cur_indices
.
long
()))
fp_xyz
=
[
sa_xyz
[
-
1
]]
fp_features
=
[
sa_features
[
-
1
]]
fp_indices
=
[
sa_indices
[
-
1
]]
for
i
in
range
(
self
.
num_fp
):
fp_features
.
append
(
self
.
FP_modules
[
i
](
sa_xyz
[
self
.
num_sa
-
i
-
1
],
sa_xyz
[
self
.
num_sa
-
i
],
sa_features
[
self
.
num_sa
-
i
-
1
],
fp_features
[
-
1
]))
fp_xyz
.
append
(
sa_xyz
[
self
.
num_sa
-
i
-
1
])
fp_indices
.
append
(
sa_indices
[
self
.
num_sa
-
i
-
1
])
x
=
nn
.
softmax
(
fp_features
)
fp_features
=
x
*
fp_features
ret
=
dict
(
fp_xyz
=
fp_xyz
,
fp_features
=
fp_features
,
fp_indices
=
fp_indices
)
return
ret
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment