generated from thinkode/modelRepository
Initial commit and v1.0

demucs/solver.py (new file, 405 lines)

@@ -0,0 +1,405 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Main training loop."""

import logging

from dora import get_xp
from dora.utils import write_and_rename
from dora.log import LogProgress, bold
import torch
import torch.nn.functional as F

from . import augment, distrib, states, pretrained
from .apply import apply_model
from .ema import ModelEMA
from .evaluate import evaluate, new_sdr
from .svd import svd_penalty
from .utils import pull_metric, EMA

logger = logging.getLogger(__name__)

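# _summary joins formatted metrics into a one-line report. For instance (illustrative
# values only), _summary({'loss': '0.1234', 'nsdr': '7.321'}) returns
# "Loss=0.1234 | Nsdr=7.321".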
def _summary(metrics):
    return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items())

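# Solver bundles everything one training run needs: the data loaders, the (optionally
# distributed) model, the optimizer, an optional quantizer, per-batch and per-epoch EMA
# copies of the weights, and the Dora experiment folder used for checkpoints and the
# metric history.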
class Solver(object):
    def __init__(self, loaders, model, optimizer, args):
        self.args = args
        self.loaders = loaders

        self.model = model
        self.optimizer = optimizer
        self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer)
        self.dmodel = distrib.wrap(model)
        self.device = next(iter(self.model.parameters())).device

        # Exponential moving average of the model, either updated every batch or epoch.
        # The best model from all the EMAs and the original one is kept based on the valid
        # loss for the final best model.
        self.emas = {'batch': [], 'epoch': []}
        for kind in self.emas.keys():
            decays = getattr(args.ema, kind)
            device = self.device if kind == 'batch' else 'cpu'
            if decays:
                for decay in decays:
                    self.emas[kind].append(ModelEMA(self.model, decay, device=device))

        # data augment
        augments = [augment.Shift(shift=int(args.dset.samplerate * args.dset.shift),
                                  same=args.augment.shift_same)]
        if args.augment.flip:
            augments += [augment.FlipChannels(), augment.FlipSign()]
        for aug in ['scale', 'remix']:
            kw = getattr(args.augment, aug)
            if kw.proba:
                augments.append(getattr(augment, aug.capitalize())(**kw))
        self.augment = torch.nn.Sequential(*augments)

        xp = get_xp()
        self.folder = xp.folder
        # Checkpoints
        self.checkpoint_file = xp.folder / 'checkpoint.th'
        self.best_file = xp.folder / 'best.th'
        logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve())
        self.best_state = None
        self.best_changed = False

        self.link = xp.link
        self.history = self.link.history

        self._reset()

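    # _serialize (below) writes everything needed to resume training into a single
    # 'checkpoint.th' package: model and optimizer state dicts, the metric history, the
    # current best_state, the args, and one entry per EMA ('ema_batch_0', 'ema_epoch_0', ...).
    # write_and_rename writes to a temporary file and renames it into place, which should
    # keep the previous checkpoint intact if the save is interrupted.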
    def _serialize(self, epoch):
        package = {}
        package['state'] = self.model.state_dict()
        package['optimizer'] = self.optimizer.state_dict()
        package['history'] = self.history
        package['best_state'] = self.best_state
        package['args'] = self.args
        for kind, emas in self.emas.items():
            for k, ema in enumerate(emas):
                package[f'ema_{kind}_{k}'] = ema.state_dict()
        with write_and_rename(self.checkpoint_file) as tmp:
            torch.save(package, tmp)

        save_every = self.args.save_every
        if save_every and (epoch + 1) % save_every == 0 and epoch + 1 != self.args.epochs:
            with write_and_rename(self.folder / f'checkpoint_{epoch + 1}.th') as tmp:
                torch.save(package, tmp)

        if self.best_changed:
            # Saving only the latest best model.
            with write_and_rename(self.best_file) as tmp:
                package = states.serialize_model(self.model, self.args)
                package['state'] = self.best_state
                torch.save(package, tmp)
            self.best_changed = False

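    # _reset (below) restores state with a simple precedence: a checkpoint already present
    # in the experiment folder wins; otherwise `continue_pretrained` pulls weights from the
    # pretrained repository; otherwise `continue_from` loads another experiment's checkpoint
    # (optionally its best state and its optimizer).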
    def _reset(self):
        """Reset state of the solver, potentially using checkpoint."""
        if self.checkpoint_file.exists():
            logger.info(f'Loading checkpoint model: {self.checkpoint_file}')
            package = torch.load(self.checkpoint_file, 'cpu')
            self.model.load_state_dict(package['state'])
            self.optimizer.load_state_dict(package['optimizer'])
            self.history[:] = package['history']
            self.best_state = package['best_state']
            for kind, emas in self.emas.items():
                for k, ema in enumerate(emas):
                    ema.load_state_dict(package[f'ema_{kind}_{k}'])
        elif self.args.continue_pretrained:
            model = pretrained.get_model(
                name=self.args.continue_pretrained,
                repo=self.args.pretrained_repo)
            self.model.load_state_dict(model.state_dict())
        elif self.args.continue_from:
            name = 'checkpoint.th'
            root = self.folder.parent
            cf = root / str(self.args.continue_from) / name
            logger.info("Loading from %s", cf)
            package = torch.load(cf, 'cpu')
            self.best_state = package['best_state']
            if self.args.continue_best:
                self.model.load_state_dict(package['best_state'], strict=False)
            else:
                self.model.load_state_dict(package['state'], strict=False)
            if self.args.continue_opt:
                self.optimizer.load_state_dict(package['optimizer'])

    def _format_train(self, metrics: dict) -> dict:
        """Formatting for train/valid metrics."""
        losses = {
            'loss': format(metrics['loss'], ".4f"),
            'reco': format(metrics['reco'], ".4f"),
        }
        if 'nsdr' in metrics:
            losses['nsdr'] = format(metrics['nsdr'], ".3f")
        if self.quantizer is not None:
            losses['ms'] = format(metrics['ms'], ".2f")
        if 'grad' in metrics:
            losses['grad'] = format(metrics['grad'], ".4f")
        if 'best' in metrics:
            losses['best'] = format(metrics['best'], '.4f')
        if 'bname' in metrics:
            losses['bname'] = metrics['bname']
        if 'penalty' in metrics:
            losses['penalty'] = format(metrics['penalty'], ".4f")
        if 'hloss' in metrics:
            losses['hloss'] = format(metrics['hloss'], ".4f")
        return losses

    def _format_test(self, metrics: dict) -> dict:
        """Formatting for test metrics."""
        losses = {}
        if 'sdr' in metrics:
            losses['sdr'] = format(metrics['sdr'], '.3f')
        if 'nsdr' in metrics:
            losses['nsdr'] = format(metrics['nsdr'], '.3f')
        for source in self.model.sources:
            key = f'sdr_{source}'
            if key in metrics:
                losses[key] = format(metrics[key], '.3f')
            key = f'nsdr_{source}'
            if key in metrics:
                losses[key] = format(metrics[key], '.3f')
        return losses

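    # train (below) drives the whole run: replay logged metrics when resuming, then for
    # each epoch run one training pass, validate the raw model and every EMA copy (keeping
    # the best one according to args.test.metric), periodically evaluate on the test set,
    # and checkpoint from rank 0.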
    def train(self):
        # Optimizing the model
        if self.history:
            logger.info("Replaying metrics from previous run")
            for epoch, metrics in enumerate(self.history):
                formatted = self._format_train(metrics['train'])
                logger.info(
                    bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
                formatted = self._format_train(metrics['valid'])
                logger.info(
                    bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))
                if 'test' in metrics:
                    formatted = self._format_test(metrics['test'])
                    if formatted:
                        logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))

        epoch = 0
        for epoch in range(len(self.history), self.args.epochs):
            # Train one epoch
            self.model.train()  # Turn on BatchNorm & Dropout
            metrics = {}
            logger.info('-' * 70)
            logger.info("Training...")
            metrics['train'] = self._run_one_epoch(epoch)
            formatted = self._format_train(metrics['train'])
            logger.info(
                bold(f'Train Summary | Epoch {epoch + 1} | {_summary(formatted)}'))

            # Cross validation
            logger.info('-' * 70)
            logger.info('Cross validation...')
            self.model.eval()  # Turn off Batchnorm & Dropout
            with torch.no_grad():
                valid = self._run_one_epoch(epoch, train=False)
                bvalid = valid
                bname = 'main'
                state = states.copy_state(self.model.state_dict())
                metrics['valid'] = {}
                metrics['valid']['main'] = valid
                key = self.args.test.metric
                for kind, emas in self.emas.items():
                    for k, ema in enumerate(emas):
                        with ema.swap():
                            valid = self._run_one_epoch(epoch, train=False)
                        name = f'ema_{kind}_{k}'
                        metrics['valid'][name] = valid
                        a = valid[key]
                        b = bvalid[key]
                        if key.startswith('nsdr'):
                            a = -a
                            b = -b
                        if a < b:
                            bvalid = valid
                            state = ema.state
                            bname = name
                metrics['valid'].update(bvalid)
                metrics['valid']['bname'] = bname
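            # args.test.metric is read as a loss (lower is better); nsdr is a quality
            # score (higher is better), hence the sign flip above before comparing the
            # plain model against each EMA candidate.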

            valid_loss = metrics['valid'][key]
            mets = pull_metric(self.link.history, f'valid.{key}') + [valid_loss]
            if key.startswith('nsdr'):
                best_loss = max(mets)
            else:
                best_loss = min(mets)
            metrics['valid']['best'] = best_loss
            if self.args.svd.penalty > 0:
                kw = dict(self.args.svd)
                kw.pop('penalty')
                with torch.no_grad():
                    penalty = svd_penalty(self.model, exact=True, **kw)
                metrics['valid']['penalty'] = penalty

            formatted = self._format_train(metrics['valid'])
            logger.info(
                bold(f'Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}'))

            # Save the best model
            if valid_loss == best_loss or self.args.dset.train_valid:
                logger.info(bold('New best valid loss %.4f'), valid_loss)
                self.best_state = states.copy_state(state)
                self.best_changed = True

            # Eval model every `test.every` epoch or on last epoch
            should_eval = (epoch + 1) % self.args.test.every == 0
            is_last = epoch == self.args.epochs - 1
            # # Tries to detect divergence in a reliable way and finish job
            # # not to waste compute.
            # # Commented out as this was super specific to the MDX competition.
            # reco = metrics['valid']['main']['reco']
            # div = epoch >= 180 and reco > 0.18
            # div = div or epoch >= 100 and reco > 0.25
            # div = div and self.args.optim.loss == 'l1'
            # if div:
            #     logger.warning("Finishing training early because valid loss is too high.")
            #     is_last = True
            if should_eval or is_last:
                # Evaluate on the testset
                logger.info('-' * 70)
                logger.info('Evaluating on the test set...')
                # We switch to the best known model for testing
                if self.args.test.best:
                    state = self.best_state
                else:
                    state = states.copy_state(self.model.state_dict())
                compute_sdr = self.args.test.sdr and is_last
                with states.swap_state(self.model, state):
                    with torch.no_grad():
                        metrics['test'] = evaluate(self, compute_sdr=compute_sdr)
                formatted = self._format_test(metrics['test'])
                logger.info(bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}"))
            self.link.push_metrics(metrics)

            if distrib.rank == 0:
                # Save model each epoch
                self._serialize(epoch)
                logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve())
            if is_last:
                break

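    # _run_one_epoch (below) performs a single pass over the train or valid loader. In
    # training mode a batch holds the ground-truth stems, which are augmented and summed
    # to form the mixture; in validation mode index 0 along the source dimension is the
    # mixture and the remaining entries are the references.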
    def _run_one_epoch(self, epoch, train=True):
        args = self.args
        data_loader = self.loaders['train'] if train else self.loaders['valid']
        if distrib.world_size > 1 and train:
            data_loader.sampler.set_epoch(epoch)

        label = ["Valid", "Train"][train]
        name = label + f" | Epoch {epoch + 1}"
        total = len(data_loader)
        if args.max_batches:
            total = min(total, args.max_batches)
        logprog = LogProgress(logger, data_loader, total=total,
                              updates=self.args.misc.num_prints, name=name)
        averager = EMA()

        for idx, sources in enumerate(logprog):
            sources = sources.to(self.device)
            if train:
                sources = self.augment(sources)
                mix = sources.sum(dim=1)
            else:
                mix = sources[:, 0]
                sources = sources[:, 1:]

            if not train and self.args.valid_apply:
                estimate = apply_model(self.model, mix, split=self.args.test.split, overlap=0)
            else:
                estimate = self.dmodel(mix)
            if train and hasattr(self.model, 'transform_target'):
                sources = self.model.transform_target(mix, sources)
            assert estimate.shape == sources.shape, (estimate.shape, sources.shape)
            dims = tuple(range(2, sources.dim()))

            if args.optim.loss == 'l1':
                loss = F.l1_loss(estimate, sources, reduction='none')
                loss = loss.mean(dims).mean(0)
                reco = loss
            elif args.optim.loss == 'mse':
                loss = F.mse_loss(estimate, sources, reduction='none')
                loss = loss.mean(dims)
                reco = loss**0.5
                reco = reco.mean(0)
            else:
                raise ValueError(f"Invalid loss {args.optim.loss}")
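            # Per-source weighting: args.weights holds one weight per separated source
            # (illustratively [1., 1., 1., 1.] for a 4-stem model), and the scalar loss
            # below is the weighted mean sum(w_k * loss_k) / sum(w_k) over sources.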
            weights = torch.tensor(args.weights).to(sources)
            loss = (loss * weights).sum() / weights.sum()

            ms = 0
            if self.quantizer is not None:
                ms = self.quantizer.model_size()
            if args.quant.diffq:
                loss += args.quant.diffq * ms

            losses = {}
            losses['reco'] = (reco * weights).sum() / weights.sum()
            losses['ms'] = ms

            if not train:
                nsdrs = new_sdr(sources, estimate.detach()).mean(0)
                total = 0
                for source, nsdr, w in zip(self.model.sources, nsdrs, weights):
                    losses[f'nsdr_{source}'] = nsdr
                    total += w * nsdr
                losses['nsdr'] = total / weights.sum()

            if train and args.svd.penalty > 0:
                kw = dict(args.svd)
                kw.pop('penalty')
                penalty = svd_penalty(self.model, **kw)
                losses['penalty'] = penalty
                loss += args.svd.penalty * penalty

            losses['loss'] = loss

            for k, source in enumerate(self.model.sources):
                losses[f'reco_{source}'] = reco[k]

            # optimize model in training mode
            if train:
                loss.backward()
                grad_norm = 0
                grads = []
                for p in self.model.parameters():
                    if p.grad is not None:
                        grad_norm += p.grad.data.norm()**2
                        grads.append(p.grad.data)
                losses['grad'] = grad_norm ** 0.5
                if args.optim.clip_grad:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        args.optim.clip_grad)

                if self.args.flag == 'uns':
                    for n, p in self.model.named_parameters():
                        if p.grad is None:
                            print('no grad', n)
                self.optimizer.step()
                self.optimizer.zero_grad()
                for ema in self.emas['batch']:
                    ema.update()
            losses = averager(losses)
            logs = self._format_train(losses)
            logprog.update(**logs)
            # Just in case, clear some memory
            del loss, estimate, reco, ms
            if args.max_batches == idx:
                break
            if self.args.debug and train:
                break
            if self.args.flag == 'debug':
                break
        if train:
            for ema in self.emas['epoch']:
                ema.update()
        return distrib.average(losses, idx + 1)
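# Minimal usage sketch (illustrative only; the actual entry point lives elsewhere in the
# project and builds `args`, the model, the optimizer and the loaders from the experiment
# configuration):
#
#     solver = Solver(loaders={'train': train_loader, 'valid': valid_loader},
#                     model=model, optimizer=optimizer, args=args)
#     solver.train()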