Commit 15bb43fb authored by Bobholamovic

Hello Github
# Adapted from https://github.com/github/gitignore/blob/master/Python.gitignore
*.code-workspace
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Config files
config*.yaml
!/config_base.yaml
# Experiment folder
/exp/
[submodule "src/models"]
path = src/models
url = git@github.com:rcdaudt/fully_convolutional_change_detection.git
# Basic configurations
# Data
# Common
dataset: OSCD
crop_size: 256
num_workers: 1
repeats: 1000
# Optimizer
optimizer: Adam
lr: 0.001
lr_mode: const
weight_decay: 0.0001
step: 2
# Training related
batch_size: 32
num_epochs: 20
resume: ''
load_optim: True
anew: False
trace_freq: 1
device: cuda
metrics: 'F1Score+Accuracy+Recall+Precision'
# Experiment
exp_dir: ../exp/
out_dir: ''
# tag: ''
# suffix: ''
# DO NOT specify exp-config term
save_on: False
log_off: False
suffix_off: False
# Criterion
criterion: NLL
weights:
- 1.0 # Weight of no-change class
- 10.0 # Weight of change class
# Model
model: siamunet_conc
num_feats_in: 13
# Global constants
# Dataset directories
IMDB_OSCD = '~/Datasets/OSCDDataset/'
IMDB_AC = '~/Datasets/SZTAKI_AirChange_Benchmark/'
# Checkpoint templates
CKP_LATEST = 'checkpoint_latest.pth'
CKP_BEST = 'model_best.pth'
CKP_COUNTED = 'checkpoint_{e:03d}.pth'
from functools import wraps
from inspect import isfunction, isgeneratorfunction, getmembers
from collections.abc import Iterable
from itertools import chain
from importlib import import_module
import torch
import torch.nn as nn
import torch.utils.data as data
import constants
import utils.metrics as metrics
from utils.misc import R
class _Desc:
def __init__(self, key):
self.key = key
def __get__(self, instance, owner):
return tuple(getattr(instance[_],self.key) for _ in range(len(instance)))
def __set__(self, instance, values):
if not (isinstance(values, Iterable) and len(values)==len(instance)):
raise TypeError("incorrect type or number of values")
for i, v in zip(range(len(instance)), values):
setattr(instance[i], self.key, v)
def _func_deco(func_name):
def _wrapper(self, *args):
# TODO: Add key argument support
try:
# Dispatch type 1
ret = tuple(getattr(ins, func_name)(*args) for ins in self)
except Exception:
# Dispatch type 2
            if len(args) != 1 or len(args[0]) != len(self): raise
ret = tuple(getattr(i, func_name)(a) for i, a in zip(self, args[0]))
return ret
return _wrapper
def _generator_deco(func_name):
def _wrapper(self, *args, **kwargs):
for ins in self:
yield from getattr(ins, func_name)(*args, **kwargs)
return _wrapper
# Duck typing
class Duck(tuple):
__ducktype__ = object
def __new__(cls, *args):
if any(not isinstance(a, cls.__ducktype__) for a in args):
raise TypeError("please check the input type")
return tuple.__new__(cls, args)
class DuckMeta(type):
def __new__(cls, name, bases, attrs):
assert len(bases) == 1
for k, v in getmembers(bases[0]):
if k.startswith('__'):
continue
if isgeneratorfunction(v):
attrs[k] = _generator_deco(k)
elif isfunction(v):
attrs[k] = _func_deco(k)
else:
attrs[k] = _Desc(k)
attrs['__ducktype__'] = bases[0]
return super().__new__(cls, name, (Duck,), attrs)
class DuckModel(nn.Module, metaclass=DuckMeta):
pass
class DuckOptimizer(torch.optim.Optimizer, metaclass=DuckMeta):
@property
def param_groups(self):
return list(chain.from_iterable(ins.param_groups for ins in self))
class DuckCriterion(nn.Module, metaclass=DuckMeta):
pass
class DuckDataset(data.Dataset, metaclass=DuckMeta):
pass
def _import_module(pkg: str, mod: str, rel=False):
if not rel:
# Use absolute import
return import_module('.'.join([pkg, mod]), package=None)
else:
return import_module('.'+mod, package=pkg)
def single_model_factory(model_name, C):
name = model_name.strip().upper()
if name == 'SIAMUNET_CONC':
from models.siamunet_conc import SiamUnet_conc
return SiamUnet_conc(C.num_feats_in, 2)
elif name == 'SIAMUNET_DIFF':
from models.siamunet_diff import SiamUnet_diff
return SiamUnet_diff(C.num_feats_in, 2)
else:
raise NotImplementedError("{} is not a supported architecture".format(model_name))
def single_optim_factory(optim_name, params, C):
name = optim_name.upper()
if name == 'ADAM':
return torch.optim.Adam(
params,
betas=(0.9, 0.999),
lr=C.lr,
weight_decay=C.weight_decay
)
else:
raise NotImplementedError("{} is not a supported optimizer type".format(optim_name))
def single_critn_factory(critn_name, C):
import losses
try:
criterion, params = {
'L1': (nn.L1Loss, ()),
'MSE': (nn.MSELoss, ()),
'CE': (nn.CrossEntropyLoss, (torch.Tensor(C.weights),)),
'NLL': (nn.NLLLoss, (torch.Tensor(C.weights),))
}[critn_name.upper()]
return criterion(*params)
except KeyError:
raise NotImplementedError("{} is not a supported criterion type".format(critn_name))
def single_train_ds_factory(ds_name, C):
from data.augmentation import Compose, Crop, Flip
module = _import_module('data', ds_name.strip())
dataset = getattr(module, ds_name+'Dataset')
configs = dict(
phase='train',
transforms=(Compose(Crop(C.crop_size), Flip()), None, None),
repeats=C.repeats
)
if ds_name == 'OSCD':
configs.update(
dict(
root = constants.IMDB_OSCD
)
)
elif ds_name == 'AC':
configs.update(
dict(
root = constants.IMDB_AC
)
)
else:
pass
dataset_obj = dataset(**configs)
return data.DataLoader(
dataset_obj,
batch_size=C.batch_size,
shuffle=True,
num_workers=C.num_workers,
pin_memory=not (C.device == 'cpu'), drop_last=True
)
def single_val_ds_factory(ds_name, C):
module = _import_module('data', ds_name.strip())
dataset = getattr(module, ds_name+'Dataset')
configs = dict(
phase='val',
transforms=(None, None, None)
)
if ds_name == 'OSCD':
configs.update(
dict(
root = constants.IMDB_OSCD
)
)
elif ds_name == 'AC':
configs.update(
dict(
                root = constants.IMDB_AC
)
)
else:
pass
dataset_obj = dataset(**configs)
# Create eval set
return data.DataLoader(
dataset_obj,
batch_size=1,
shuffle=False,
num_workers=C.num_workers,
pin_memory=False, drop_last=False
)
def _parse_input_names(name_str):
return name_str.split('+')
def model_factory(model_names, C):
name_list = _parse_input_names(model_names)
if len(name_list) > 1:
return DuckModel(*(single_model_factory(name, C) for name in name_list))
else:
return single_model_factory(model_names, C)
def optim_factory(optim_names, params, C):
name_list = _parse_input_names(optim_names)
if len(name_list) > 1:
return DuckOptimizer(*(single_optim_factory(name, params, C) for name in name_list))
else:
return single_optim_factory(optim_names, params, C)
def critn_factory(critn_names, C):
name_list = _parse_input_names(critn_names)
if len(name_list) > 1:
return DuckCriterion(*(single_critn_factory(name, C) for name in name_list))
else:
return single_critn_factory(critn_names, C)
def data_factory(dataset_names, phase, C):
name_list = _parse_input_names(dataset_names)
if phase not in ('train', 'val'):
raise ValueError("phase should be either 'train' or 'val'")
fact = globals()['single_'+phase+'_ds_factory']
if len(name_list) > 1:
return DuckDataset(*(fact(name, C) for name in name_list))
else:
return fact(dataset_names, C)
def metric_factory(metric_names, C):
from utils import metrics
name_list = _parse_input_names(metric_names)
return [getattr(metrics, name)() for name in name_list]
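# Usage sketch (an addition; it assumes this package and its `losses`/`utils`
# modules are importable, and that C carries the fields referenced above,
# e.g. C.weights and C.metrics). '+'-joined names yield Duck* containers.
def __demo_factories(C):
    critn = critn_factory('NLL+L1', C)   # DuckCriterion wrapping two losses
    mets = metric_factory(C.metrics, C)  # e.g. [F1Score(), Accuracy(), ...]
    print(type(critn).__name__, [m.__name__ for m in mets])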
import shutil
import os
from types import MappingProxyType
from copy import deepcopy
import torch
from skimage import io
from tqdm import tqdm
import constants
from data.common import to_array
from utils.misc import R
from utils.metrics import AverageMeter
from utils.utils import mod_crop
from .factories import (model_factory, optim_factory, critn_factory, data_factory, metric_factory)
class Trainer:
def __init__(self, model, dataset, criterion, optimizer, settings):
super().__init__()
context = deepcopy(settings)
self.ctx = MappingProxyType(vars(context))
self.phase = context.cmd
self.logger = R['LOGGER']
self.gpc = R['GPC'] # Global Path Controller
self.path = self.gpc.get_path
self.batch_size = context.batch_size
self.checkpoint = context.resume
self.load_checkpoint = (len(self.checkpoint)>0)
self.num_epochs = context.num_epochs
self.lr = float(context.lr)
self.save = context.save_on or context.out_dir
self.out_dir = context.out_dir
self.trace_freq = context.trace_freq
self.device = context.device
self.suffix_off = context.suffix_off
for k, v in sorted(self.ctx.items()):
self.logger.show("{}: {}".format(k,v))
self.model = model_factory(model, context)
self.model.to(self.device)
self.criterion = critn_factory(criterion, context)
self.criterion.to(self.device)
self.optimizer = optim_factory(optimizer, self.model.parameters(), context)
self.metrics = metric_factory(context.metrics, context)
self.train_loader = data_factory(dataset, 'train', context)
self.val_loader = data_factory(dataset, 'val', context)
self.start_epoch = 0
self._init_max_acc = 0.0
def train_epoch(self):
raise NotImplementedError
def validate_epoch(self, epoch=0, store=False):
raise NotImplementedError
def train(self):
if self.load_checkpoint:
self._resume_from_checkpoint()
max_acc = self._init_max_acc
best_epoch = self.get_ckp_epoch()
for epoch in range(self.start_epoch, self.num_epochs):
lr = self._adjust_learning_rate(epoch)
self.logger.show_nl("Epoch: [{0}]\tlr {1:.06f}".format(epoch, lr))
# Train for one epoch
self.train_epoch()
# Evaluate the model on validation set
self.logger.show_nl("Validate")
acc = self.validate_epoch(epoch=epoch, store=self.save)
is_best = acc > max_acc
if is_best:
max_acc = acc
best_epoch = epoch
self.logger.show_nl("Current: {:.6f} ({:03d})\tBest: {:.6f} ({:03d})\t".format(
acc, epoch, max_acc, best_epoch))
            # The checkpoint records the epoch to resume from, i.e. epoch+1
self._save_checkpoint(self.model.state_dict(), self.optimizer.state_dict(), max_acc, epoch+1, is_best)
def validate(self):
if self.checkpoint:
if self._resume_from_checkpoint():
self.validate_epoch(self.get_ckp_epoch(), self.save)
else:
self.logger.warning("no checkpoint assigned!")
def _adjust_learning_rate(self, epoch):
if self.ctx['lr_mode'] == 'step':
lr = self.lr * (0.5 ** (epoch // self.ctx['step']))
elif self.ctx['lr_mode'] == 'poly':
lr = self.lr * (1 - epoch / self.num_epochs) ** 1.1
elif self.ctx['lr_mode'] == 'const':
lr = self.lr
else:
raise ValueError('unknown lr mode {}'.format(self.ctx['lr_mode']))
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
return lr
def _resume_from_checkpoint(self):
if not os.path.isfile(self.checkpoint):
self.logger.error("=> no checkpoint found at '{}'".format(self.checkpoint))
return False
self.logger.show("=> loading checkpoint '{}'".format(
self.checkpoint))
checkpoint = torch.load(self.checkpoint)
state_dict = self.model.state_dict()
ckp_dict = checkpoint.get('state_dict', checkpoint)
update_dict = {k:v for k,v in ckp_dict.items()
if k in state_dict and state_dict[k].shape == v.shape}
num_to_update = len(update_dict)
if (num_to_update < len(state_dict)) or (len(state_dict) < len(ckp_dict)):
if self.phase == 'val' and (num_to_update < len(state_dict)):
self.logger.error("=> mismatched checkpoint for validation")
return False
self.logger.warning("warning: trying to load an mismatched checkpoint")
if num_to_update == 0:
self.logger.error("=> no parameter is to be loaded")
return False
else:
self.logger.warning("=> {} params are to be loaded".format(num_to_update))
elif (not self.ctx['anew']) or (self.phase != 'train'):
            # Note: in non-anew mode, the max_acc field stored here is not
            # guaranteed to correspond to the checkpoint actually loaded
self.start_epoch = checkpoint.get('epoch', self.start_epoch)
self._init_max_acc = checkpoint.get('max_acc', self._init_max_acc)
if self.ctx['load_optim']:
try:
# Note that weight decay might be modified here
self.optimizer.load_state_dict(checkpoint['optimizer'])
except KeyError:
self.logger.warning("warning: failed to load optimizer parameters")
state_dict.update(update_dict)
self.model.load_state_dict(state_dict)
self.logger.show("=> loaded checkpoint '{}' (epoch {}, max_acc {:.4f})".format(
self.checkpoint, self.get_ckp_epoch(), self._init_max_acc
))
return True
def _save_checkpoint(self, state_dict, optim_state, max_acc, epoch, is_best):
state = {
'epoch': epoch,
'state_dict': state_dict,
'optimizer': optim_state,
'max_acc': max_acc
}
# Save history
history_path = self.path('weight', constants.CKP_COUNTED.format(e=epoch), underline=True)
if epoch % self.trace_freq == 0:
torch.save(state, history_path)
# Save latest
latest_path = self.path(
'weight', constants.CKP_LATEST,
underline=True
)
torch.save(state, latest_path)
if is_best:
shutil.copyfile(
latest_path, self.path(
'weight', constants.CKP_BEST,
underline=True
)
)
def get_ckp_epoch(self):
# Get current epoch of the checkpoint
        # For a mismatched checkpoint or no checkpoint, this is 0
return max(self.start_epoch-1, 0)
def save_image(self, file_name, image, epoch):
file_path = os.path.join(
'epoch_{}/'.format(epoch),
self.out_dir,
file_name
)
out_path = self.path(
'out', file_path,
suffix=not self.suffix_off,
auto_make=True,
underline=True
)
return io.imsave(out_path, image)
class CDTrainer(Trainer):
def __init__(self, arch, dataset, optimizer, settings):
super().__init__(arch, dataset, 'NLL', optimizer, settings)
def train_epoch(self):
losses = AverageMeter()
len_train = len(self.train_loader)
pb = tqdm(self.train_loader)
self.model.train()
for i, (t1, t2, label) in enumerate(pb):
t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
prob = self.model(t1, t2)
loss = self.criterion(prob, label)
losses.update(loss.item(), n=self.batch_size)
# Compute gradients and do SGD step
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
desc = self.logger.make_desc(
i+1, len_train,
('loss', losses, '.4f')
)
pb.set_description(desc)
self.logger.dump(desc)
def validate_epoch(self, epoch=0, store=False):
self.logger.show_nl("Epoch: [{0}]".format(epoch))
losses = AverageMeter()
len_val = len(self.val_loader)
pb = tqdm(self.val_loader)
self.model.eval()
with torch.no_grad():
for i, (name, t1, t2, label) in enumerate(pb):
t1, t2, label = t1.to(self.device), t2.to(self.device), label.to(self.device)
prob = self.model(t1, t2)
loss = self.criterion(prob, label)
losses.update(loss.item(), n=self.batch_size)
# Convert to numpy arrays
CM = to_array(torch.argmax(prob, 1)).astype('uint8')
label = to_array(label[0]).astype('uint8')
for m in self.metrics:
m.update(CM, label)
desc = self.logger.make_desc(
i+1, len_val,
('loss', losses, '.4f'),
*(
(m.__name__, m, '.4f')
for m in self.metrics
)
)
pb.set_description(desc)
self.logger.dump(desc)
if store:
self.save_image(name[0], CM, epoch)
return self.metrics[0].avg if len(self.metrics) > 0 else max(1.0 - losses.avg, self._init_max_acc)
from glob import glob
from os.path import join, basename
from multiprocessing import Manager
import numpy as np
from . import CDDataset
from .common import default_loader
class OSCDDataset(CDDataset):
__BAND_NAMES = (
'B01', 'B02', 'B03', 'B04', 'B05', 'B06',
'B07', 'B08', 'B8A', 'B09', 'B10', 'B11', 'B12'
)
def __init__(
self,
root, phase='train',
transforms=(None, None, None),
repeats=1,
cache_labels=True
):
super().__init__(root, phase, transforms, repeats)
self.cache_on = cache_labels
if self.cache_on:
self._manager = Manager()
self.label_pool = self._manager.dict()
def _read_file_paths(self):
image_dir = join(self.root, 'Onera Satellite Change Detection dataset - Images')
label_dir = join(self.root, 'Onera Satellite Change Detection dataset - Train Labels')
txt_file = join(image_dir, 'train.txt')
# Read cities
with open(txt_file, 'r') as f:
cities = [city.strip() for city in f.read().strip().split(',')]
        if self.phase == 'train':
            # For training, use all pairs except the last 3
            cities = cities[:-3]
        else:
            # For validation, use the last 3 pairs
            cities = cities[-3:]
# t1_list, t2_list = [], []
# for city in cities:
# t1s = glob(join(image_dir, city, 'imgs_1', '*_B??.tif'))
# t1_list.append(t1s) # Populate t1_list
# # Recognize t2 from t1
# prefix = glob(join(image_dir, city, 'imgs_2/*_B01.tif'))[0][:-5]
# t2_list.append([prefix+t1[-5:] for t1 in t1s])
#
# Use resampled images
t1_list = [[join(image_dir, city, 'imgs_1_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
t2_list = [[join(image_dir, city, 'imgs_2_rect', band+'.tif') for band in self.__BAND_NAMES] for city in cities]
label_list = [join(label_dir, city, 'cm', city+'-cm.tif') for city in cities]
return t1_list, t2_list, label_list
def fetch_image(self, image_paths):
return np.stack([default_loader(p) for p in image_paths], axis=-1).astype(np.float32)
def fetch_label(self, label_path):
if self.cache_on:
label = self.label_pool.get(label_path, None)
if label is not None:
return label
        # In the tif labels, 1 means no change and 2 means change,
        # so a -1 offset is applied here
label = default_loader(label_path) - 1
if self.cache_on:
self.label_pool[label_path] = label
return label
from os.path import join, expanduser, basename
import torch
import torch.utils.data as data
import numpy as np
from .common import (default_loader, to_tensor)
class CDDataset(data.Dataset):
def __init__(
self,
root, phase,
transforms,
repeats
):
super().__init__()
self.root = expanduser(root)
self.phase = phase
self.transforms = transforms
self.repeats = repeats
self.t1_list, self.t2_list, self.label_list = self._read_file_paths()
self.len = len(self.label_list)
def __len__(self):
return self.len * self.repeats
def __getitem__(self, index):
if index >= len(self):
raise IndexError
index = index % self.len
t1 = self.fetch_image(self.t1_list[index])
t2 = self.fetch_image(self.t2_list[index])
label = self.fetch_label(self.label_list[index])
t1, t2, label = self.preprocess(t1, t2, label)
if self.phase == 'train':
return t1, t2, label
else:
return basename(self.label_list[index]), t1, t2, label
def _read_file_paths(self):
raise NotImplementedError
def fetch_label(self, label_path):
return default_loader(label_path)
def fetch_image(self, image_path):
return default_loader(image_path)
def preprocess(self, t1, t2, label):
if self.transforms[0] is not None:
# Applied on all
t1, t2, label = self.transforms[0](t1, t2, label)
if self.transforms[1] is not None:
# For images solely
t1, t2 = self.transforms[1](t1, t2)
if self.transforms[2] is not None:
# For labels solely
label = self.transforms[2](label)
return to_tensor(t1).float(), to_tensor(t2).float(), to_tensor(label).long()
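# Sketch of extending CDDataset (an addition): a concrete dataset only needs
# _read_file_paths() to return three aligned path lists; repeats, transforms,
# and tensor conversion are inherited. The paths below are hypothetical.
class _ToyCDDataset(CDDataset):
    def _read_file_paths(self):
        t1_list = ['t1/0001.png', 't1/0002.png']
        t2_list = ['t2/0001.png', 't2/0002.png']
        label_list = ['cm/0001.png', 'cm/0002.png']
        return t1_list, t2_list, label_list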
import random
from functools import partial, wraps
import numpy as np
import cv2
rand = random.random
randi = random.randint
choice = random.choice
uniform = random.uniform
# gauss = random.gauss
gauss = random.normalvariate # This one is thread-safe
# The transformations treat numpy ndarrays only
def _istuple(x): return isinstance(x, (tuple, list))
class Transform:
def __init__(self, random_state=False):
self.random_state = random_state
def _transform(self, x):
raise NotImplementedError
def __call__(self, *args):
if self.random_state: self._set_rand_param()
assert len(args) > 0
return self._transform(args[0]) if len(args) == 1 else tuple(map(self._transform, args))
def _set_rand_param(self):
raise NotImplementedError
class Compose:
def __init__(self, *tf):
assert len(tf) > 0
self.tfs = tf
def __call__(self, *x):
if len(x) > 1:
for tf in self.tfs: x = tf(*x)
else:
x = x[0]
for tf in self.tfs: x = tf(x)
return x
class Scale(Transform):
def __init__(self, scale=(0.5,1.0)):
if _istuple(scale):
assert len(scale) == 2
self.scale_range = scale #sorted(scale)
self.scale = scale[0]
super(Scale, self).__init__(random_state=True)
else:
super(Scale, self).__init__(random_state=False)
self.scale = scale
def _transform(self, x):
# assert x.ndim == 3
        h, w = x.shape[:2]
        size = (int(h*self.scale), int(w*self.scale))
        if size == (h, w):
            return x
        interp = cv2.INTER_LINEAR if np.issubdtype(x.dtype, np.floating) else cv2.INTER_NEAREST
        # Note that cv2.resize takes dsize in (width, height) order
        return cv2.resize(x, size[::-1], interpolation=interp)
def _set_rand_param(self):
self.scale = uniform(*self.scale_range)
class DiscreteScale(Scale):
def __init__(self, bins=(0.5, 0.75), keep_prob=0.5):
super(DiscreteScale, self).__init__(scale=(min(bins), 1.0))
self.bins = bins
self.keep_prob = keep_prob
def _set_rand_param(self):
self.scale = 1.0 if rand()<self.keep_prob else choice(self.bins)
class Flip(Transform):
# Flip or rotate
_directions = ('ud', 'lr', 'no', '90', '180', '270')
def __init__(self, direction=None):
super(Flip, self).__init__(random_state=(direction is None))
self.direction = direction
if direction is not None: assert direction in self._directions
def _transform(self, x):
if self.direction == 'ud':
            # torch doesn't support numpy arrays with negative strides, hence the copies
return np.ascontiguousarray(x[::-1])
elif self.direction == 'lr':
return np.ascontiguousarray(x[:,::-1])
elif self.direction == 'no':
return x
elif self.direction == '90':
# Clockwise
return np.ascontiguousarray(self._T(x)[:,::-1])
elif self.direction == '180':
return np.ascontiguousarray(x[::-1,::-1])
elif self.direction == '270':
return np.ascontiguousarray(self._T(x)[::-1])
else:
raise ValueError('invalid flipping direction')
def _set_rand_param(self):
self.direction = choice(self._directions)
@staticmethod
def _T(x):
return np.swapaxes(x, 0, 1)
class HorizontalFlip(Flip):
_directions = ('lr', 'no')
def __init__(self, flip=None):
if flip is not None: flip = self._directions[~flip]
super(HorizontalFlip, self).__init__(direction=flip)
class VerticalFlip(Flip):
_directions = ('ud', 'no')
def __init__(self, flip=None):
if flip is not None: flip = self._directions[~flip]
super(VerticalFlip, self).__init__(direction=flip)
class Crop(Transform):
_inner_bounds = ('bl', 'br', 'tl', 'tr', 't', 'b', 'l', 'r')
def __init__(self, crop_size=None, bounds=None):
__no_bounds = (bounds is None)
super(Crop, self).__init__(random_state=__no_bounds)
if __no_bounds:
assert crop_size is not None
else:
if not((_istuple(bounds) and len(bounds)==4) or (isinstance(bounds, str) and bounds in self._inner_bounds)):
raise ValueError('invalid bounds')
self.bounds = bounds
self.crop_size = crop_size if _istuple(crop_size) else (crop_size, crop_size)
def _transform(self, x):
h, w = x.shape[:2]
if self.bounds == 'bl':
return x[h//2:,:w//2]
elif self.bounds == 'br':
return x[h//2:,w//2:]
elif self.bounds == 'tl':
return x[:h//2,:w//2]
elif self.bounds == 'tr':
return x[:h//2,w//2:]
elif self.bounds == 't':
return x[:h//2]
elif self.bounds == 'b':
return x[h//2:]
elif self.bounds == 'l':
return x[:,:w//2]
elif self.bounds == 'r':
return x[:,w//2:]
elif len(self.bounds) == 2:
            assert self.crop_size[0] <= h and self.crop_size[1] <= w
ch, cw = self.crop_size
cx, cy = int((w-cw+1)*self.bounds[0]), int((h-ch+1)*self.bounds[1])
return x[cy:cy+ch, cx:cx+cw]
else:
left, top, right, lower = self.bounds
return x[top:lower, left:right]
def _set_rand_param(self):
self.bounds = (rand(), rand())
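# Usage sketch (an addition): a single call with several inputs draws the
# random parameters once and applies them to every input, which is what keeps
# an image pair (and its label) pixel-aligned under random crop/flip.
def __demo_paired_transforms():
    a = np.random.rand(8, 8, 3).astype(np.float32)
    b = a.copy()
    tf = Compose(Crop(4), Flip())
    ca, cb = tf(a, b)
    assert (ca == cb).all()  # same crop bounds and flip direction for both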
class MSCrop(Crop):
def __init__(self, scale, crop_size=None):
super(MSCrop, self).__init__(crop_size)
self.scale = scale # Scale factor
def __call__(self, lr, hr):
if self.random_state:
self._set_rand_param()
        # Random scaling of the crop bounds may cause pixel misalignment between
        # the lr-hr pair, which significantly hurts training, so the crop is
        # specified as an explicit (left, top, width, height) quadruple instead
left, top, cw, ch = self._get_quad(*lr.shape[:2])
self._set_quad(left, top, cw, ch)
lr_crop = self._transform(lr)
left, top, cw, ch = [int(it*self.scale) for it in (left, top, cw, ch)]
self._set_quad(left, top, cw, ch)
hr_crop = self._transform(hr)
return lr_crop, hr_crop
def _get_quad(self, h, w):
ch, cw = self.crop_size
cx, cy = int((w-cw+1)*self.bounds[0]), int((h-ch+1)*self.bounds[1])
return cx, cy, cw, ch
def _set_quad(self, left, top, cw, ch):
self.bounds = (left, top, left+cw, top+ch)
# Color jittering and transformation
# The following classes are partially based on https://github.com/albu/albumentations/
class _ValueTransform(Transform):
def __init__(self, rs, limit=(0, 255)):
super().__init__(rs)
self.limit = limit
self.limit_range = limit[1] - limit[0]
@staticmethod
def keep_range(tf):
@wraps(tf)
def wrapper(obj, x):
# # Make a copy
# x = x.copy()
x = tf(obj, np.clip(x, *obj.limit))
return np.clip(x, *obj.limit)
return wrapper
class ColorJitter(_ValueTransform):
_channel = (0,1,2)
def __init__(self, shift=((-20,20), (-20,20), (-20,20)), limit=(0,255)):
super().__init__(False, limit)
_nc = len(self._channel)
if _nc == 1:
if _istuple(shift):
rs = True
self.shift = self.range = shift
else:
rs = False
self.shift = (shift,)
self.range = (shift, shift)
else:
if _istuple(shift):
if len(shift) != _nc:
raise ValueError("specify the shift value (or range) for every channel")
rs = all(_istuple(s) for s in shift)
self.shift = self.range = shift
else:
rs = False
self.shift = [shift for _ in range(_nc)]
self.range = [(shift, shift) for _ in range(_nc)]
self.random_state = rs
def _(x):
return x, ()
self.convert_to = _
self.convert_back = _
@_ValueTransform.keep_range
def _transform(self, x):
# CAUTION!
# Type conversion here
x, params = self.convert_to(x)
for i, c in enumerate(self._channel):
x[...,c] += self.shift[i]
x[...,c] = self._clip(x[...,c])
x, _ = self.convert_back(x, *params)
return x
def _clip(self, x):
return np.clip(x, *self.limit)
def _set_rand_param(self):
if len(self._channel) == 1:
self.shift = [uniform(*self.range)]
else:
self.shift = [uniform(*r) for r in self.range]
class HSVShift(ColorJitter):
def __init__(self, shift, limit):
super().__init__(shift, limit)
def _convert_to(x):
type_x = x.dtype
x = x.astype(np.float32)
# Normalize to [0,1]
x -= self.limit[0]
x /= self.limit_range
x = cv2.cvtColor(x, code=cv2.COLOR_RGB2HSV)
return x, (type_x,)
        def _convert_back(x, type_x):
            x = cv2.cvtColor(x.astype(np.float32), code=cv2.COLOR_HSV2RGB)
            # Rescale before casting back, so integer dtypes don't truncate to 0/1
            return (x * self.limit_range + self.limit[0]).astype(type_x), ()
# Pack conversion methods
self.convert_to = _convert_to
self.convert_back = _convert_back
class HueShift(HSVShift):
_channel = (0,)
def __init__(self, shift=(-20, 20), limit=(0, 255)):
super().__init__(shift, limit)
def _clip(self, x):
# Circular
# Note that this works in Opencv 3.4.3, not yet tested under other versions
x[x<0] += 360
x[x>360] -= 360
return x
class SaturationShift(HSVShift):
_channel = (1,)
def __init__(self, shift=(-30, 30), limit=(0, 255)):
super().__init__(shift, limit)
self.range = tuple(r / self.limit_range for r in self.range)
def _clip(self, x):
return np.clip(x, 0, 1.0)
class RGBShift(ColorJitter):
def __init__(self, shift=((-20,20), (-20,20), (-20,20)), limit=(0, 255)):
super().__init__(shift, limit)
class RShift(RGBShift):
_channel = (0,)
def __init__(self, shift=(-20,20), limit=(0, 255)):
super().__init__(shift, limit)
class GShift(RGBShift):
_channel = (1,)
def __init__(self, shift=(-20,20), limit=(0, 255)):
super().__init__(shift, limit)
class BShift(RGBShift):
_channel = (2,)
def __init__(self, shift=(-20,20), limit=(0, 255)):
super().__init__(shift, limit)
class PCAJitter(_ValueTransform):
def __init__(self, sigma=0.3, limit=(0, 255)):
# For RGB only
super().__init__(True, limit)
self.sigma = sigma
@_ValueTransform.keep_range
def _transform(self, x):
old_shape = x.shape
x = np.reshape(x, (-1,3), order='F') # For RGB
x_mean = np.mean(x, 0)
x -= x_mean
        cov_x = np.cov(x, rowvar=False)
        # The covariance matrix is symmetric, so eigh suffices (np.mat is deprecated)
        eig_vals, eig_vecs = np.linalg.eigh(cov_x)
        # The eigenvectors (columns) are already unit-length; perturb along them
        noise = eig_vecs @ (eig_vals * self.alpha)
        x += noise
return np.reshape(x+x_mean, old_shape, order='F')
def _set_rand_param(self):
self.alpha = [gauss(0, self.sigma) for _ in range(3)]
class ContraBrightScale(_ValueTransform):
def __init__(self, alpha=(-0.2, 0.2), beta=(-0.2, 0.2), limit=(0, 255)):
super().__init__(_istuple(alpha) or _istuple(beta), limit)
self.alpha = alpha
self.alpha_range = alpha if _istuple(alpha) else (alpha, alpha)
self.beta = beta
self.beta_range = beta if _istuple(beta) else (beta, beta)
@_ValueTransform.keep_range
def _transform(self, x):
if self.alpha != 1:
x *= self.alpha
if self.beta != 0:
x += self.beta*np.mean(x)
return x
def _set_rand_param(self):
self.alpha = uniform(*self.alpha_range)
self.beta = uniform(*self.beta_range)
class ContrastScale(ContraBrightScale):
def __init__(self, alpha=(0.2, 0.8), limit=(0,255)):
super().__init__(alpha=alpha, beta=0, limit=limit)
class BrightnessScale(ContraBrightScale):
def __init__(self, beta=(-0.2, 0.2), limit=(0,255)):
super().__init__(alpha=1, beta=beta, limit=limit)
class _AddNoise(_ValueTransform):
def __init__(self, limit):
super().__init__(True, limit)
self._im_shape = (0, 0)
@_ValueTransform.keep_range
def _transform(self, x):
return x + self.noise_map
def __call__(self, *args):
shape = args[0].shape
if any(im.shape != shape for im in args):
raise ValueError("the input images should be of same size")
self._im_shape = shape
return super().__call__(*args)
class AddGaussNoise(_AddNoise):
def __init__(self, mu=0.0, sigma=0.1, limit=(0, 255)):
super().__init__(limit)
self.mu = mu
self.sigma = sigma
def _set_rand_param(self):
self.noise_map = np.random.randn(*self._im_shape)*self.sigma + self.mu
def __test():
a = np.arange(12).reshape((2,2,3)).astype(np.float64)
tf = Compose(BrightnessScale(), AddGaussNoise(), HueShift())
print(a[...,0])
c = tf(a)
print(c[...,0])
print(a[...,0])
if __name__ == '__main__':
__test()
import torch
import numpy as np
from scipy.io import loadmat
from skimage.io import imread
def default_loader(path_):
return imread(path_)
def mat_loader(path_):
return loadmat(path_)
def make_onehot(index_map, n):
# Only deals with tensors with no batch dim
old_size = index_map.size()
z = torch.zeros(n, *old_size[-2:]).type_as(index_map)
z.scatter_(0, index_map, 1)
return z
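# Usage sketch (an addition): make_onehot expects an index map with a leading
# channel dim of 1, i.e. shape (1, H, W), so scatter_ can fill along dim 0.
def __demo_make_onehot():
    index_map = torch.tensor([[[0, 1], [1, 0]]])  # (1, 2, 2)
    onehot = make_onehot(index_map, 2)            # (2, 2, 2)
    print(onehot[1])                              # 1 exactly where index == 1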
def to_tensor(arr):
if arr.ndim < 3:
return torch.from_numpy(arr)
elif arr.ndim == 3:
return torch.from_numpy(np.ascontiguousarray(np.transpose(arr, (2,0,1))))
else:
raise NotImplementedError
def to_array(tensor):
if tensor.ndimension() < 3:
return tensor.data.cpu().numpy()
elif tensor.ndimension() in (3, 4):
return np.ascontiguousarray(np.moveaxis(tensor.data.cpu().numpy(), -3, -1))
else:
raise NotImplementedError
import torch
import torch.nn as nn
import torch.nn.functional as F
Subproject commit f3a8e382e3446ca2daca752c7e8f22f30291afff
#!/usr/bin/env python3
import argparse
import os
import shutil
import random
import ast
from os.path import basename, exists
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import yaml
from core.trainers import CDTrainer
from utils.misc import OutPathGetter, Logger, register
def read_config(config_path):
    with open(config_path, 'r') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    return cfg or {}
def parse_config(cfg_name, cfg):
# Parse the name of config file
sp = cfg_name.split('.')[0].split('_')
if len(sp) >= 2:
cfg.setdefault('tag', sp[1])
cfg.setdefault('suffix', '_'.join(sp[2:]))
return cfg
def parse_args():
# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('cmd', choices=['train', 'val'])
# Data
# Common
group_data = parser.add_argument_group('data')
group_data.add_argument('-d', '--dataset', type=str, default='OSCD')
group_data.add_argument('-p', '--crop-size', type=int, default=256, metavar='P',
help='patch size (default: %(default)s)')
group_data.add_argument('--num-workers', type=int, default=8)
group_data.add_argument('--repeats', type=int, default=100)
# Optimizer
group_optim = parser.add_argument_group('optimizer')
group_optim.add_argument('--optimizer', type=str, default='Adam')
group_optim.add_argument('--lr', type=float, default=1e-4, metavar='LR',
help='learning rate (default: %(default)s)')
group_optim.add_argument('--lr-mode', type=str, default='const')
group_optim.add_argument('--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: %(default)s)')
group_optim.add_argument('--step', type=int, default=200)
# Training related
group_train = parser.add_argument_group('training related')
group_train.add_argument('--batch-size', type=int, default=8, metavar='B',
help='input batch size for training (default: %(default)s)')
group_train.add_argument('--num-epochs', type=int, default=1000, metavar='NE',
help='number of epochs to train (default: %(default)s)')
group_train.add_argument('--load-optim', action='store_true')
group_train.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint')
group_train.add_argument('--anew', action='store_true',
help='clear history and start from epoch 0 with the checkpoint loaded')
group_train.add_argument('--trace-freq', type=int, default=50)
group_train.add_argument('--device', type=str, default='cpu')
group_train.add_argument('--metrics', type=str, default='F1Score+Accuracy+Recall+Precision')
# Experiment
group_exp = parser.add_argument_group('experiment related')
group_exp.add_argument('--exp-dir', default='../exp/')
group_exp.add_argument('-o', '--out-dir', default='')
group_exp.add_argument('--tag', type=str, default='')
group_exp.add_argument('--suffix', type=str, default='')
group_exp.add_argument('--exp-config', type=str, default='')
group_exp.add_argument('--save-on', action='store_true')
group_exp.add_argument('--log-off', action='store_true')
group_exp.add_argument('--suffix-off', action='store_true')
# Criterion
group_critn = parser.add_argument_group('criterion related')
group_critn.add_argument('--criterion', type=str, default='NLL')
group_critn.add_argument('--weights', type=str, default=(1.0, 1.0))
# Model
group_model = parser.add_argument_group('model')
group_model.add_argument('--model', type=str, default='siamunet_conc')
group_model.add_argument('--num-feats-in', type=int, default=13)
args = parser.parse_args()
if exists(args.exp_config):
cfg = read_config(args.exp_config)
cfg = parse_config(basename(args.exp_config), cfg)
# Settings from cfg file overwrite those in args
# Note that the non-default values will not be affected
parser.set_defaults(**cfg) # Reset part of the default values
args = parser.parse_args() # Parse again
# Handle args.weights
if isinstance(args.weights, str):
args.weights = ast.literal_eval(args.weights)
args.weights = tuple(args.weights)
return args
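# Self-contained sketch of the precedence implemented above (an addition):
# config-file values replace parser *defaults*, so flags given explicitly on
# the command line still win when the arguments are parsed again.
def __demo_override():
    p = argparse.ArgumentParser()
    p.add_argument('--lr', type=float, default=1e-4)
    p.set_defaults(**{'lr': 0.001})           # as if read from a config file
    print(p.parse_args([]).lr)                # 0.001: the config default
    print(p.parse_args(['--lr', '0.01']).lr)  # 0.01: CLI overrides config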
def set_gpc_and_logger(args):
gpc = OutPathGetter(
root=os.path.join(args.exp_dir, args.tag),
suffix=args.suffix)
log_dir = '' if args.log_off else gpc.get_dir('log')
logger = Logger(
scrn=True,
log_dir=log_dir,
phase=args.cmd
)
register('GPC', gpc)
register('LOGGER', logger)
return gpc, logger
def main():
args = parse_args()
gpc, logger = set_gpc_and_logger(args)
if exists(args.exp_config):
# Make a copy of the config file
cfg_path = gpc.get_path('root', basename(args.exp_config), suffix=False)
shutil.copy(args.exp_config, cfg_path)
# Set random seed
RNG_SEED = 1
random.seed(RNG_SEED)
np.random.seed(RNG_SEED)
torch.manual_seed(RNG_SEED)
cudnn.deterministic = True
cudnn.benchmark = False
try:
trainer = CDTrainer(args.model, args.dataset, args.optimizer, args)
if args.cmd == 'train':
trainer.train()
elif args.cmd == 'val':
trainer.validate()
else:
pass
except BaseException as e:
import traceback
# Catch ALL kinds of exceptions
logger.error(traceback.format_exc())
exit(1)
if __name__ == '__main__':
main()
from sklearn import metrics
class AverageMeter:
def __init__(self, callback=None):
super().__init__()
self.callback = callback
self.reset()
def compute(self, *args):
if self.callback is not None:
return self.callback(*args)
elif len(args) == 1:
return args[0]
else:
raise NotImplementedError
def reset(self):
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0
def update(self, *args, n=1):
self.val = self.compute(*args)
self.sum += self.val * n
self.count += n
self.avg = self.sum / self.count
class Metric(AverageMeter):
__name__ = 'Metric'
def __init__(self, callback, **configs):
super().__init__(callback)
self.configs = configs
def compute(self, pred, true):
return self.callback(true.ravel(), pred.ravel(), **self.configs)
class Precision(Metric):
__name__ = 'Prec.'
def __init__(self, **configs):
super().__init__(metrics.precision_score, **configs)
class Recall(Metric):
__name__ = 'Recall'
def __init__(self, **configs):
super().__init__(metrics.recall_score, **configs)
class Accuracy(Metric):
__name__ = 'OA'
def __init__(self, **configs):
super().__init__(metrics.accuracy_score, **configs)
class F1Score(Metric):
__name__ = 'F1'
def __init__(self, **configs):
super().__init__(metrics.f1_score, **configs)
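# Usage sketch (an addition): update() ravels the (pred, true) maps, forwards
# them to the sklearn callback, and maintains a running average across calls.
def __demo_metrics():
    import numpy as np
    f1 = F1Score()
    pred = np.array([[0, 1], [1, 0]], dtype=np.uint8)
    true = np.array([[0, 1], [1, 1]], dtype=np.uint8)
    f1.update(pred, true)
    print(f1.__name__, f1.avg)  # F1 0.8 for this single update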
import logging
import os
from time import localtime
from collections import OrderedDict
from weakref import proxy
FORMAT_LONG = "[%(asctime)-15s %(funcName)s] %(message)s"
FORMAT_SHORT = "%(message)s"
class Logger:
_count = 0
def __init__(self, scrn=True, log_dir='', phase=''):
super().__init__()
self._logger = logging.getLogger('logger_{}'.format(Logger._count))
Logger._count += 1
self._logger.setLevel(logging.DEBUG)
if scrn:
self._scrn_handler = logging.StreamHandler()
self._scrn_handler.setLevel(logging.INFO)
self._scrn_handler.setFormatter(logging.Formatter(fmt=FORMAT_SHORT))
self._logger.addHandler(self._scrn_handler)
if log_dir and phase:
self.log_path = os.path.join(log_dir,
'{}-{:-4d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}.log'.format(
phase, *localtime()[:6]
))
self.show_nl("log into {}\n\n".format(self.log_path))
self._file_handler = logging.FileHandler(filename=self.log_path)
self._file_handler.setLevel(logging.DEBUG)
self._file_handler.setFormatter(logging.Formatter(fmt=FORMAT_LONG))
self._logger.addHandler(self._file_handler)
def show(self, *args, **kwargs):
return self._logger.info(*args, **kwargs)
def show_nl(self, *args, **kwargs):
self._logger.info("")
return self.show(*args, **kwargs)
def dump(self, *args, **kwargs):
return self._logger.debug(*args, **kwargs)
def warning(self, *args, **kwargs):
return self._logger.warning(*args, **kwargs)
def error(self, *args, **kwargs):
return self._logger.error(*args, **kwargs)
@staticmethod
def make_desc(counter, total, *triples):
desc = "[{}/{}]".format(counter, total)
# The three elements of each triple are
# (name to display, AverageMeter object, formatting string)
for name, obj, fmt in triples:
desc += (" {} {obj.val:"+fmt+"} ({obj.avg:"+fmt+"})").format(name, obj=obj)
return desc
_default_logger = Logger()
class _WeakAttribute:
def __get__(self, instance, owner):
return instance.__dict__[self.name]
def __set__(self, instance, value):
if value is not None:
value = proxy(value)
instance.__dict__[self.name] = value
def __set_name__(self, owner, name):
self.name = name
class _TreeNode:
_sep = '/'
_none = None
parent = _WeakAttribute() # To avoid circular reference
def __init__(self, name, value=None, parent=None, children=None):
super().__init__()
self.name = name
self.val = value
self.parent = parent
self.children = children if isinstance(children, dict) else {}
if isinstance(children, list):
for child in children:
self._add_child(child)
self.path = name
def get_child(self, name, def_val=None):
return self.children.get(name, def_val)
def set_child(self, name, val=None):
r"""
Set the value of an existing node.
If the node does not exist, return nothing
"""
child = self.get_child(name)
if child is not None:
child.val = val
return child
def add_place_holder(self, name):
return self.add_child(name, val=self._none)
def add_child(self, name, val):
r"""
        If the child does not exist or is a placeholder, create it;
        otherwise skip and return the existing node
"""
child = self.get_child(name, None)
if child is None:
child = _TreeNode(name, val, parent=self)
self._add_child(child)
elif child.val == self._none:
            # Retain the links of the placeholder,
            # i.e. just fill it in
child.val = val
return child
def is_leaf(self):
return len(self.children) == 0
def __repr__(self):
try:
repr = self.path + ' ' + str(self.val)
except TypeError:
repr = self.path
return repr
def __contains__(self, name):
return name in self.children.keys()
def __getitem__(self, key):
return self.get_child(key)
def _add_child(self, node):
r""" Into children dictionary and set path and parent """
self.children.update({
node.name: node
})
node.path = self._sep.join([self.path, node.name])
node.parent = self
def apply(self, func):
r"""
Apply a callback function on ALL descendants
This is useful for the recursive traversal
"""
ret = [func(self)]
for _, node in self.children.items():
ret.extend(node.apply(func))
return ret
def bfs_tracker(self):
queue = []
queue.insert(0, self)
while(queue):
curr = queue.pop()
yield curr
if curr.is_leaf():
continue
for c in curr.children.values():
queue.insert(0, c)
class _Tree:
def __init__(
self, name, value=None, strc_ele=None,
sep=_TreeNode._sep, def_val=_TreeNode._none
):
super().__init__()
self._sep = sep
self._def_val = def_val
self.root = _TreeNode(name, value, parent=None, children={})
if strc_ele is not None:
assert isinstance(strc_ele, dict)
# This is to avoid mutable parameter default
self.build_tree(OrderedDict(strc_ele or {}))
def build_tree(self, elements):
# The siblings could be out-of-order
for path, ele in elements.items():
self.add_node(path, ele)
    def get_root(self):
        r""" Get a detached copy of the root node """
        return _TreeNode(
            self.root.name, self.root.val,
            parent=None, children=None
        )
def __repr__(self):
return self.__dumps__()
def __dumps__(self):
r""" Dump to string """
_str = ''
# DFS
stack = []
stack.append((self.root, 0))
while(stack):
root, layer = stack.pop()
_str += ' '*layer + '-' + root.__repr__() + '\n'
if root.is_leaf():
continue
# Note that the order of the siblings is not retained
for c in reversed(list(root.children.values())):
stack.append((c, layer+1))
return _str
def vis(self):
r""" Visualize the structure of the tree """
_default_logger.show(self.__dumps__())
def __contains__(self, obj):
return any(self.perform(lambda node: obj in node))
def perform(self, func):
return self.root.apply(func)
def get_node(self, tar, mode='name'):
r"""
This is different from the travasal in that
the search allows early stop
"""
if mode == 'path':
nodes = self.parse_path(tar)
root = self.root
for r in nodes:
if root is None:
root = root.get_child(r)
return root
else:
# BFS
bfs_tracker = self.root.bfs_tracker()
# bfs_tracker.send(None)
for node in bfs_tracker:
if getattr(node, mode) == tar:
return node
return
def set_node(self, path, val):
        node = self.get_node(path, mode='path')
if node is not None:
node.val = val
return node
def add_node(self, path, val=None):
if not path.strip():
raise ValueError("the path is null")
if val is None:
val = self._def_val
names = self.parse_path(path)
root = self.root
nodes = [root]
for name in names[:-1]:
# Add placeholders
root = root.add_child(name, self._def_val)
nodes.append(root)
root = root.add_child(names[-1], val)
return root, nodes
def parse_path(self, path):
return path.split(self._sep)
def join(self, *args):
return self._sep.join(args)
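# Usage sketch (an addition): _Tree turns '/'-separated paths into nested
# nodes, inserting placeholders for intermediate names along the way.
def __demo_tree():
    t = _Tree('exp', 'root', strc_ele={'weights': 'weight', 'logs/train': 'log'})
    t.vis()  # dumps an indented view, one "-path value" line per node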
class OutPathGetter:
def __init__(self, root='', log='logs', out='outs', weight='weights', suffix='', **subs):
super().__init__()
self._root = root.rstrip('/') # Work robustly for multiple ending '/'s
self._suffix = suffix
self._keys = dict(log=log, out=out, weight=weight, **subs)
self._dir_tree = _Tree(
self._root, 'root',
strc_ele=dict(zip(self._keys.values(), self._keys.keys())),
sep='/',
def_val=''
)
self.update_keys(False)
self.update_tree(False)
self.__counter = 0
def __str__(self):
return '\n'+self.sub_dirs
@property
def sub_dirs(self):
return str(self._dir_tree)
@property
def root(self):
return self._root
def _update_key(self, key, val, add=False, prefix=False):
if prefix:
val = os.path.join(self._root, val)
if add:
# Do not edit if exists
self._keys.setdefault(key, val)
else:
            self._keys[key] = val
def _add_node(self, key, val, prefix=False):
if not prefix and key.startswith(self._root):
key = key[len(self._root)+1:]
return self._dir_tree.add_node(key, val)
def update_keys(self, verbose=False):
for k, v in self._keys.items():
self._update_key(k, v, prefix=True)
if verbose:
_default_logger.show(self._keys)
def update_tree(self, verbose=False):
self._dir_tree.perform(lambda x: self.make_dir(x.path))
if verbose:
_default_logger.show("\nFolder structure:")
_default_logger.show(self._dir_tree)
@staticmethod
def make_dir(path):
if not os.path.exists(path):
os.mkdir(path)
def get_dir(self, key):
return self._keys.get(key, '') if key != 'root' else self.root
def get_path(
self, key, file,
name='', auto_make=False,
suffix=True, underline=False
):
folder = self.get_dir(key)
if len(folder) < 1:
raise KeyError("key not found")
if suffix:
path = os.path.join(folder, self.add_suffix(file, underline=underline))
else:
path = os.path.join(folder, file)
if auto_make:
base_dir = os.path.dirname(path)
if base_dir in self:
return path
if name:
self._update_key(name, base_dir, add=True)
'''
else:
name = 'new_{:03d}'.format(self.__counter)
self._update_key(name, base_dir, add=True)
self.__counter += 1
'''
des, visit = self._add_node(base_dir, name)
# Create directories along the visiting path
for d in visit: self.make_dir(d.path)
self.make_dir(des.path)
return path
def add_suffix(self, path, suffix='', underline=False):
pos = path.rfind('.')
if pos == -1:
pos = len(path)
_suffix = self._suffix if len(suffix) < 1 else suffix
return path[:pos] + ('_' if underline and _suffix else '') + _suffix + path[pos:]
def __contains__(self, value):
return value in self._keys.values()
class Registry(dict):
def register(self, key, val):
if key in self: _default_logger.warning("key {} already registered".format(key))
self[key] = val
R = Registry()
R.register('DEFAULT_LOGGER', _default_logger)
register = R.register
import math
import torch
import numpy as np
def mod_crop(blob, N):
if isinstance(blob, np.ndarray):
# For numpy arrays, channels at the last dim
h, w = blob.shape[-3:-1]
nh = h - h % N
nw = w - w % N
return blob[..., :nh, :nw, :]
else:
# For 4-D pytorch tensors, channels at the 2nd dim
with torch.no_grad():
h, w = blob.shape[-2:]
nh = h - h % N
nw = w - w % N
return blob[..., :nh, :nw]
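# Usage sketch (an addition): mod_crop trims the spatial dims to multiples of N.
def __demo_mod_crop():
    x = np.zeros((7, 9, 3))
    print(mod_crop(x, 4).shape)  # (4, 8, 3)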