Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
F
FCN-CD-PyTorch
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
manli
FCN-CD-PyTorch
Commits
36e94e06
Commit
36e94e06
authored
5 years ago
by
Bobholamovic
Browse files
Options
Downloads
Patches
Plain Diff
Happy New Year
parent
4d7e2b88
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
src/core/factories.py
+94
-43
94 additions, 43 deletions
src/core/factories.py
src/core/trainers.py
+10
-6
10 additions, 6 deletions
src/core/trainers.py
src/utils/metrics.py
+61
-20
61 additions, 20 deletions
src/utils/metrics.py
with
165 additions
and
69 deletions
src/core/factories.py
+
94
−
43
View file @
36e94e06
...
...
@@ -12,6 +12,7 @@ import constants
import
utils.metrics
as
metrics
from
utils.misc
import
R
class
_Desc
:
def
__init__
(
self
,
key
):
self
.
key
=
key
...
...
@@ -26,15 +27,7 @@ class _Desc:
def
_func_deco
(
func_name
):
def
_wrapper
(
self
,
*
args
):
# TODO: Add key argument support
try
:
# Dispatch type 1
ret
=
tuple
(
getattr
(
ins
,
func_name
)(
*
args
)
for
ins
in
self
)
except
Exception
:
# Dispatch type 2
if
len
(
args
)
>
1
or
(
len
(
args
[
0
])
!=
len
(
self
)):
raise
ret
=
tuple
(
getattr
(
i
,
func_name
)(
a
)
for
i
,
a
in
zip
(
self
,
args
[
0
]))
return
ret
return
tuple
(
getattr
(
ins
,
func_name
)(
*
args
)
for
ins
in
self
)
return
_wrapper
...
...
@@ -45,6 +38,16 @@ def _generator_deco(func_name):
return
_wrapper
def
_mark
(
func
):
func
.
__marked__
=
True
return
func
def
_unmark
(
func
):
func
.
__marked__
=
False
return
func
# Duck typing
class
Duck
(
tuple
):
__ducktype__
=
object
...
...
@@ -60,6 +63,9 @@ class DuckMeta(type):
for
k
,
v
in
getmembers
(
bases
[
0
]):
if
k
.
startswith
(
'
__
'
):
continue
if
k
in
attrs
and
hasattr
(
attrs
[
k
],
'
__marked__
'
):
if
attrs
[
k
].
__marked__
:
continue
if
isgeneratorfunction
(
v
):
attrs
[
k
]
=
_generator_deco
(
k
)
elif
isfunction
(
v
):
...
...
@@ -71,14 +77,48 @@ class DuckMeta(type):
class
DuckModel
(
nn
.
Module
,
metaclass
=
DuckMeta
):
pass
DELIM
=
'
:
'
@_mark
def
load_state_dict
(
self
,
state_dict
):
dicts
=
[
dict
()
for
_
in
range
(
len
(
self
))]
for
k
,
v
in
state_dict
.
items
():
i
,
*
k
=
k
.
split
(
self
.
DELIM
)
k
=
self
.
DELIM
.
join
(
k
)
i
=
int
(
i
)
dicts
[
i
][
k
]
=
v
for
i
in
range
(
len
(
self
)):
self
[
i
].
load_state_dict
(
dicts
[
i
])
@_mark
def
state_dict
(
self
):
dict_
=
dict
()
for
i
,
ins
in
enumerate
(
self
):
dict_
.
update
({
self
.
DELIM
.
join
([
str
(
i
),
key
]):
val
for
key
,
val
in
ins
.
state_dict
().
items
()})
return
dict_
class
DuckOptimizer
(
torch
.
optim
.
Optimizer
,
metaclass
=
DuckMeta
):
DELIM
=
'
:
'
@property
def
param_groups
(
self
):
return
list
(
chain
.
from_iterable
(
ins
.
param_groups
for
ins
in
self
))
@_mark
def
state_dict
(
self
):
dict_
=
dict
()
for
i
,
ins
in
enumerate
(
self
):
dict_
.
update
({
self
.
DELIM
.
join
([
str
(
i
),
key
]):
val
for
key
,
val
in
ins
.
state_dict
().
items
()})
return
dict_
@_mark
def
load_state_dict
(
self
,
state_dict
):
dicts
=
[
dict
()
for
_
in
range
(
len
(
self
))]
for
k
,
v
in
state_dict
.
items
():
i
,
*
k
=
k
.
split
(
self
.
DELIM
)
k
=
self
.
DELIM
.
join
(
k
)
i
=
int
(
i
)
dicts
[
i
][
k
]
=
v
for
i
in
range
(
len
(
self
)):
self
[
i
].
load_state_dict
(
dicts
[
i
])
class
DuckCriterion
(
nn
.
Module
,
metaclass
=
DuckMeta
):
pass
...
...
@@ -112,7 +152,8 @@ def single_model_factory(model_name, C):
def
single_optim_factory
(
optim_name
,
params
,
C
):
name
=
optim_name
.
strip
().
upper
()
optim_name
=
optim_name
.
strip
()
name
=
optim_name
.
upper
()
if
name
==
'
ADAM
'
:
return
torch
.
optim
.
Adam
(
params
,
...
...
@@ -133,6 +174,7 @@ def single_optim_factory(optim_name, params, C):
def
single_critn_factory
(
critn_name
,
C
):
import
losses
critn_name
=
critn_name
.
strip
()
try
:
criterion
,
params
=
{
'
L1
'
:
(
nn
.
L1Loss
,
()),
...
...
@@ -145,6 +187,19 @@ def single_critn_factory(critn_name, C):
raise
NotImplementedError
(
"
{} is not a supported criterion type
"
.
format
(
critn_name
))
def
_get_basic_configs
(
ds_name
,
C
):
if
ds_name
==
'
OSCD
'
:
return
dict
(
root
=
constants
.
IMDB_OSCD
)
elif
ds_name
.
startswith
(
'
AC
'
):
return
dict
(
root
=
constants
.
IMDB_AirChange
)
else
:
return
dict
()
def
single_train_ds_factory
(
ds_name
,
C
):
from
data.augmentation
import
Compose
,
Crop
,
Flip
ds_name
=
ds_name
.
strip
()
...
...
@@ -155,21 +210,13 @@ def single_train_ds_factory(ds_name, C):
transforms
=
(
Compose
(
Crop
(
C
.
crop_size
),
Flip
()),
None
,
None
),
repeats
=
C
.
repeats
)
if
ds_name
==
'
OSCD
'
:
configs
.
update
(
dict
(
root
=
constants
.
IMDB_OSCD
)
)
elif
ds_name
.
startswith
(
'
AC
'
):
configs
.
update
(
dict
(
root
=
constants
.
IMDB_AirChange
)
)
else
:
pass
# Update some common configurations
configs
.
update
(
_get_basic_configs
(
ds_name
,
C
))
# Set phase-specific ones
pass
dataset_obj
=
dataset
(
**
configs
)
return
data
.
DataLoader
(
...
...
@@ -190,21 +237,13 @@ def single_val_ds_factory(ds_name, C):
transforms
=
(
None
,
None
,
None
),
repeats
=
1
)
if
ds_name
==
'
OSCD
'
:
configs
.
update
(
dict
(
root
=
constants
.
IMDB_OSCD
)
)
elif
ds_name
.
startswith
(
'
AC
'
):
configs
.
update
(
dict
(
root
=
constants
.
IMDB_AirChange
)
)
else
:
pass
# Update some common configurations
configs
.
update
(
_get_basic_configs
(
ds_name
,
C
))
# Set phase-specific ones
pass
dataset_obj
=
dataset
(
**
configs
)
# Create eval set
...
...
@@ -229,12 +268,24 @@ def model_factory(model_names, C):
return
single_model_factory
(
model_names
,
C
)
def
optim_factory
(
optim_names
,
param
s
,
C
):
def
optim_factory
(
optim_names
,
model
s
,
C
):
name_list
=
_parse_input_names
(
optim_names
)
if
len
(
name_list
)
>
1
:
return
DuckOptimizer
(
*
(
single_optim_factory
(
name
,
params
,
C
)
for
name
in
name_list
))
num_models
=
len
(
models
)
if
isinstance
(
models
,
DuckModel
)
else
1
if
len
(
name_list
)
!=
num_models
:
raise
ValueError
(
"
the number of optimizers does not match the number of models
"
)
if
num_models
>
1
:
optims
=
[]
for
name
,
model
in
zip
(
name_list
,
models
):
param_groups
=
[{
'
params
'
:
module
.
parameters
(),
'
name
'
:
module_name
}
for
module_name
,
module
in
model
.
named_children
()]
optims
.
append
(
single_optim_factory
(
name
,
param_groups
,
C
))
return
DuckOptimizer
(
*
optims
)
else
:
return
single_optim_factory
(
optim_names
,
params
,
C
)
return
single_optim_factory
(
optim_names
,
[{
'
params
'
:
module
.
parameters
(),
'
name
'
:
module_name
}
for
module_name
,
module
in
models
.
named_children
()],
C
)
def
critn_factory
(
critn_names
,
C
):
...
...
This diff is collapsed.
Click to expand it.
src/core/trainers.py
+
10
−
6
View file @
36e94e06
...
...
@@ -33,8 +33,8 @@ class Trainer:
self
.
lr
=
float
(
context
.
lr
)
self
.
save
=
context
.
save_on
or
context
.
out_dir
self
.
out_dir
=
context
.
out_dir
self
.
trace_freq
=
context
.
trace_freq
self
.
device
=
context
.
device
self
.
trace_freq
=
int
(
context
.
trace_freq
)
self
.
device
=
torch
.
device
(
context
.
device
)
self
.
suffix_off
=
context
.
suffix_off
for
k
,
v
in
sorted
(
self
.
ctx
.
items
()):
...
...
@@ -44,7 +44,7 @@ class Trainer:
self
.
model
.
to
(
self
.
device
)
self
.
criterion
=
critn_factory
(
criterion
,
context
)
self
.
criterion
.
to
(
self
.
device
)
self
.
optimizer
=
optim_factory
(
optimizer
,
self
.
model
.
parameters
()
,
context
)
self
.
optimizer
=
optim_factory
(
optimizer
,
self
.
model
,
context
)
self
.
metrics
=
metric_factory
(
context
.
metrics
,
context
)
self
.
train_loader
=
data_factory
(
dataset
,
'
train
'
,
context
)
...
...
@@ -74,10 +74,14 @@ class Trainer:
# Train for one epoch
self
.
train_epoch
()
# Clear the history of metric objects
for
m
in
self
.
metrics
:
m
.
reset
()
# Evaluate the model on validation set
self
.
logger
.
show_nl
(
"
Validate
"
)
acc
=
self
.
validate_epoch
(
epoch
=
epoch
,
store
=
self
.
save
)
is_best
=
acc
>
max_acc
if
is_best
:
max_acc
=
acc
...
...
@@ -250,7 +254,7 @@ class CDTrainer(Trainer):
losses
.
update
(
loss
.
item
(),
n
=
self
.
batch_size
)
# Convert to numpy arrays
CM
=
to_array
(
torch
.
argmax
(
prob
,
1
)).
astype
(
'
uint8
'
)
CM
=
to_array
(
torch
.
argmax
(
prob
[
0
]
,
0
)).
astype
(
'
uint8
'
)
label
=
to_array
(
label
[
0
]).
astype
(
'
uint8
'
)
for
m
in
self
.
metrics
:
m
.
update
(
CM
,
label
)
...
...
@@ -267,6 +271,6 @@ class CDTrainer(Trainer):
self
.
logger
.
dump
(
desc
)
if
store
:
self
.
save_image
(
name
[
0
],
(
CM
*
255
).
squeeze
(
-
1
)
,
epoch
)
self
.
save_image
(
name
[
0
],
CM
*
255
,
epoch
)
return
self
.
metrics
[
0
].
avg
if
len
(
self
.
metrics
)
>
0
else
max
(
1.0
-
losses
.
avg
,
self
.
_init_max_acc
)
\ No newline at end of file
This diff is collapsed.
Click to expand it.
src/utils/metrics.py
+
61
−
20
View file @
36e94e06
from
functools
import
partial
import
numpy
as
np
from
sklearn
import
metrics
class
AverageMeter
:
def
__init__
(
self
,
callback
=
None
):
super
().
__init__
()
self
.
callback
=
callback
if
callback
is
not
None
:
self
.
compute
=
callback
self
.
reset
()
def
compute
(
self
,
*
args
):
if
self
.
callback
is
not
None
:
return
self
.
callback
(
*
args
)
elif
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
return
args
[
0
]
else
:
raise
NotImplementedError
def
reset
(
self
):
self
.
val
=
0
.0
self
.
avg
=
0
.0
self
.
sum
=
0
.0
self
.
val
=
0
self
.
avg
=
0
self
.
sum
=
0
self
.
count
=
0
def
update
(
self
,
*
args
,
n
=
1
):
...
...
@@ -27,36 +29,75 @@ class AverageMeter:
self
.
count
+=
n
self
.
avg
=
self
.
sum
/
self
.
count
def
__repr__
(
self
):
return
'
val: {} avg: {} cnt: {}
'
.
format
(
self
.
val
,
self
.
avg
,
self
.
count
)
# These metrics only for numpy arrays
class
Metric
(
AverageMeter
):
__name__
=
'
Metric
'
def
__init__
(
self
,
callback
,
**
configs
):
super
().
__init__
(
callback
)
self
.
configs
=
configs
def
__init__
(
self
,
n_classes
=
2
,
mode
=
'
accum
'
,
reduction
=
'
binary
'
):
super
().
__init__
(
None
)
self
.
_cm
=
AverageMeter
(
partial
(
metrics
.
confusion_matrix
,
labels
=
np
.
arange
(
n_classes
)))
assert
mode
in
(
'
accum
'
,
'
separ
'
)
self
.
mode
=
mode
assert
reduction
in
(
'
mean
'
,
'
none
'
,
'
binary
'
)
if
reduction
==
'
binary
'
and
n_classes
!=
2
:
raise
ValueError
(
"
binary reduction only works in 2-class cases
"
)
self
.
reduction
=
reduction
def
compute
(
self
,
pred
,
true
):
return
self
.
callback
(
true
.
ravel
(),
pred
.
ravel
(),
**
self
.
configs
)
def
_compute
(
self
,
cm
):
raise
NotImplementedError
def
compute
(
self
,
cm
):
if
self
.
reduction
==
'
none
'
:
# Do not reduce size
return
self
.
_compute
(
cm
)
elif
self
.
reduction
==
'
mean
'
:
# Micro averaging
return
self
.
_compute
(
cm
).
mean
()
else
:
# The pos_class be 1
return
self
.
_compute
(
cm
)[
1
]
def
update
(
self
,
pred
,
true
,
n
=
1
):
# Note that this is no thread-safe
self
.
_cm
.
update
(
true
.
ravel
(),
pred
.
ravel
())
if
self
.
mode
==
'
accum
'
:
cm
=
self
.
_cm
.
sum
elif
self
.
mode
==
'
separ
'
:
cm
=
self
.
_cm
.
val
else
:
raise
NotImplementedError
super
().
update
(
cm
,
n
=
n
)
def
__repr__
(
self
):
return
self
.
__name__
+
'
'
+
super
().
__repr__
()
class
Precision
(
Metric
):
__name__
=
'
Prec.
'
def
_
_init__
(
self
,
**
configs
):
super
().
__init__
(
metrics
.
precision_score
,
**
configs
)
def
_
compute
(
self
,
cm
):
return
np
.
nan_to_num
(
np
.
diag
(
cm
)
/
cm
.
sum
(
axis
=
0
)
)
class
Recall
(
Metric
):
__name__
=
'
Recall
'
def
_
_init__
(
self
,
**
configs
):
super
().
__init__
(
metrics
.
recall_score
,
**
configs
)
def
_
compute
(
self
,
cm
):
return
np
.
nan_to_num
(
np
.
diag
(
cm
)
/
cm
.
sum
(
axis
=
1
)
)
class
Accuracy
(
Metric
):
__name__
=
'
OA
'
def
__init__
(
self
,
**
configs
):
super
().
__init__
(
metrics
.
accuracy_score
,
**
configs
)
def
__init__
(
self
,
n_classes
=
2
,
mode
=
'
accum
'
):
super
().
__init__
(
n_classes
=
n_classes
,
mode
=
mode
,
reduction
=
'
none
'
)
def
_compute
(
self
,
cm
):
return
np
.
nan_to_num
(
np
.
diag
(
cm
).
sum
()
/
cm
.
sum
())
class
F1Score
(
Metric
):
__name__
=
'
F1
'
def
__init__
(
self
,
**
configs
):
super
().
__init__
(
metrics
.
f1_score
,
**
configs
)
\ No newline at end of file
def
_compute
(
self
,
cm
):
prec
=
np
.
nan_to_num
(
np
.
diag
(
cm
)
/
cm
.
sum
(
axis
=
0
))
recall
=
np
.
nan_to_num
(
np
.
diag
(
cm
)
/
cm
.
sum
(
axis
=
1
))
return
np
.
nan_to_num
(
2
*
(
prec
*
recall
)
/
(
prec
+
recall
))
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment