manli / FCN-CD-PyTorch · Merge requests · !2

Update outdated code

Open · manli requested to merge github/fork/Bobholamovic/master into master · 4 years ago
Overview (0) · Commits (17) · Pipelines (0) · Changes (1)

Created by: Bobholamovic

Fix known bugs and refactor the framework.
Merge request reports

Viewing commit d6d3faca (1 file, +1 −0)
Commit d6d3faca: Add cuda manual seed
Bobholamovic authored 4 years ago
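The line this commit introduces is `torch.cuda.manual_seed(RNG_SEED)`, completing the reproducibility block in src/train.py that already seeded Python's, NumPy's, and PyTorch's CPU generators. Reassembled from the diff below, the full recipe looks like this (a sketch; the commented-out `manual_seed_all` line is a common multi-GPU extension and is not part of this MR):

```python
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn

RNG_SEED = 1
random.seed(RNG_SEED)             # Python's built-in RNG
np.random.seed(RNG_SEED)          # NumPy's global RNG
torch.manual_seed(RNG_SEED)       # PyTorch CPU RNG
torch.cuda.manual_seed(RNG_SEED)  # CUDA RNG on the current device (this commit)
# torch.cuda.manual_seed_all(RNG_SEED)  # would cover every GPU; not in this MR
cudnn.deterministic = True        # force deterministic cuDNN kernels
cudnn.benchmark = False           # disable the nondeterministic autotuner
```

Note that `torch.cuda.manual_seed` seeds only the current device and is silently ignored when CUDA is unavailable, so the call is safe under the script's CPU default (`--device cpu`).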
src/train.py (+39 −134)
```diff
 #!/usr/bin/env python3
-import argparse
-import os
 import shutil
-import random
-import ast
-from os.path import basename, exists, splitext
 import os.path as osp
-
-import torch
-import torch.backends.cudnn as cudnn
-import numpy as np
-import yaml
-
-from core.trainers import CDTrainer
-from utils.misc import OutPathGetter, Logger, register
-
-
-def read_config(config_path):
-    with open(config_path, 'r') as f:
-        cfg = yaml.load(f.read(), Loader=yaml.FullLoader)
-    return cfg or {}
-
-
-def parse_config(cfg_name, cfg):
-    # Parse the name of config file
-    sp = splitext(cfg_name)[0].split('_')
-    if len(sp) >= 2:
-        cfg.setdefault('tag', sp[1])
-        cfg.setdefault('suffix', '_'.join(sp[2:]))
-    return cfg
-
-
-def parse_args():
-    # Training settings
-    parser = argparse.ArgumentParser()
-    parser.add_argument('cmd', choices=['train', 'val'])
-
-    # Data
-    # Common
-    group_data = parser.add_argument_group('data')
-    group_data.add_argument('-d', '--dataset', type=str, default='OSCD')
-    group_data.add_argument('-p', '--crop-size', type=int, default=256, metavar='P',
-                            help='patch size (default: %(default)s)')
-    group_data.add_argument('--num-workers', type=int, default=8)
-    group_data.add_argument('--repeats', type=int, default=100)
-
-    # Optimizer
-    group_optim = parser.add_argument_group('optimizer')
-    group_optim.add_argument('--optimizer', type=str, default='Adam')
-    group_optim.add_argument('--lr', type=float, default=1e-4, metavar='LR',
-                             help='learning rate (default: %(default)s)')
-    group_optim.add_argument('--lr-mode', type=str, default='const')
-    group_optim.add_argument('--weight-decay', default=1e-4, type=float, metavar='W',
-                             help='weight decay (default: %(default)s)')
-    group_optim.add_argument('--step', type=int, default=200)
-
-    # Training related
-    group_train = parser.add_argument_group('training related')
-    group_train.add_argument('--batch-size', type=int, default=8, metavar='B',
-                             help='input batch size for training (default: %(default)s)')
-    group_train.add_argument('--num-epochs', type=int, default=1000, metavar='NE',
-                             help='number of epochs to train (default: %(default)s)')
-    group_train.add_argument('--load-optim', action='store_true')
-    group_train.add_argument('--resume', default='', type=str, metavar='PATH',
-                             help='path to latest checkpoint')
-    group_train.add_argument('--anew', action='store_true',
-                             help='clear history and start from epoch 0 with the checkpoint loaded')
-    group_train.add_argument('--trace-freq', type=int, default=50)
-    group_train.add_argument('--device', type=str, default='cpu')
-    group_train.add_argument('--metrics', type=str, default='F1Score+Accuracy+Recall+Precision')
-
-    # Experiment
-    group_exp = parser.add_argument_group('experiment related')
-    group_exp.add_argument('--exp-dir', default='../exp/')
-    group_exp.add_argument('-o', '--out-dir', default='')
-    group_exp.add_argument('--tag', type=str, default='')
-    group_exp.add_argument('--suffix', type=str, default='')
-    group_exp.add_argument('--exp-config', type=str, default='')
-    group_exp.add_argument('--save-on', action='store_true')
-    group_exp.add_argument('--log-off', action='store_true')
-    group_exp.add_argument('--suffix-off', action='store_true')
-
-    # Criterion
-    group_critn = parser.add_argument_group('criterion related')
-    group_critn.add_argument('--criterion', type=str, default='NLL')
-    group_critn.add_argument('--weights', type=str, default=(1.0, 1.0))
-
-    # Model
-    group_model = parser.add_argument_group('model')
-    group_model.add_argument('--model', type=str, default='siamunet_conc')
-    group_model.add_argument('--num-feats-in', type=int, default=13)
-
-    args = parser.parse_args()
-
-    if exists(args.exp_config):
-        cfg = read_config(args.exp_config)
-        cfg = parse_config(basename(args.exp_config), cfg)
-        # Settings from cfg file overwrite those in args
-        # Note that the non-default values will not be affected
-        parser.set_defaults(**cfg)    # Reset part of the default values
-        args = parser.parse_args()    # Parse again
-
-    # Handle args.weights
-    if isinstance(args.weights, str):
-        args.weights = ast.literal_eval(args.weights)
-    args.weights = tuple(args.weights)
-
-    return args
-
-
-def set_gpc_and_logger(args):
-    gpc = OutPathGetter(root=os.path.join(args.exp_dir, args.tag), suffix=args.suffix)
-    log_dir = '' if args.log_off else gpc.get_dir('log')
-    logger = Logger(scrn=True, log_dir=log_dir, phase=args.cmd)
-    register('GPC', gpc)
-    register('LOGGER', logger)
-    return gpc, logger
+
+from core.misc import R
+from core.config import parse_args
 
 
 def main():
-    args = parse_args()
-    gpc, logger = set_gpc_and_logger(args)
-
-    if args.exp_config:
-        # Make a copy of the config file
-        cfg_path = gpc.get_path('root', basename(args.exp_config), suffix=False)
-        shutil.copy(args.exp_config, cfg_path)
-
-    # Set random seed
-    RNG_SEED = 1
-    random.seed(RNG_SEED)
-    np.random.seed(RNG_SEED)
-    torch.manual_seed(RNG_SEED)
-    torch.cuda.manual_seed(RNG_SEED)
-    cudnn.deterministic = True
-    cudnn.benchmark = False
-
-    try:
-        trainer = CDTrainer(args.model, args.dataset, args.optimizer, args)
-        trainer.run()
-    except BaseException as e:
-        import traceback
-        # Catch ALL kinds of exceptions
-        logger.fatal(traceback.format_exc())
-        exit(1)
+    # Parse commandline arguments
+    def parser_configurator(parser):
+        parser.add_argument('--crop_size', type=int, default=256, metavar='P',
+                            help="patch size (default: %(default)s)")
+        parser.add_argument('--tb_on', action='store_true')
+        parser.add_argument('--tb_intvl', type=int, default=100)
+        parser.add_argument('--suffix_off', action='store_true')
+        parser.add_argument('--lr_mode', type=str, default='const')
+        parser.add_argument('--step', type=int, default=200)
+        parser.add_argument('--save_on', action='store_true')
+        parser.add_argument('--out_dir', default='')
+        parser.add_argument('--val_iters', type=int, default=16)
+        return parser
+    args = parse_args(parser_configurator)
+
+    trainer = R['Trainer_switcher'](args)
+    if trainer is not None:
+        if args['exp_config']:
+            # Make a copy of the config file
+            cfg_path = osp.join(trainer.gpc.root, osp.basename(args['exp_config']))
+            shutil.copy(args['exp_config'], cfg_path)
+        try:
+            trainer.run()
+        except BaseException as e:
+            import traceback
+            # Catch ALL kinds of exceptions
+            trainer.logger.fatal(traceback.format_exc())
+            if args['debug_on']:
+                breakpoint()
+            exit(1)
+    else:
+        raise NotImplementedError("Cannot find an appropriate trainer!")
 
 
 if __name__ == '__main__':
     main()
\ No newline at end of file
```
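After this MR, train.py is only a thin entry point: argument parsing moves into `core.config.parse_args` (note that the new code indexes `args['exp_config']`, so `args` now behaves like a mapping rather than an `argparse.Namespace`), and the trainer is obtained through the registry `R` from `core.misc`. Neither module is shown in this diff, so the following is only a minimal sketch of such a registry, with hypothetical names and contents rather than the project's actual implementation:

```python
# Hypothetical sketch of a key-value registry in the spirit of core.misc.R.
# The real implementation is not part of this diff.
R = {}

def register(key, value):
    # Fail loudly on duplicate keys instead of silently overwriting.
    if key in R:
        raise KeyError(f"'{key}' is already registered")
    R[key] = value

# Trainer modules could register (predicate, factory) pairs at import time.
TRAINERS = []

def trainer_switcher(args):
    # Return the first trainer whose predicate accepts args, or None so the
    # entry script can raise NotImplementedError, as in the diff above.
    for matches, build in TRAINERS:
        if matches(args):
            return build(args)
    return None

register('Trainer_switcher', trainer_switcher)
```

Whatever the actual contents of `core.misc`, the registry keeps the entry script decoupled from concrete trainer classes: adding a trainer means registering it, not editing train.py.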