Initial commit and v1.0

demucs/grids/__init__.py (new file, empty)

demucs/grids/_explorers.py (new file, 64 lines)
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dora import Explorer
import treetable as tt


class MyExplorer(Explorer):
    test_metrics = ['nsdr', 'sdr_med']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        return [
            tt.group("train", [
                tt.leaf("epoch"),
                tt.leaf("reco", ".3f"),
            ], align=">"),
            tt.group("valid", [
                tt.leaf("penalty", ".1f"),
                tt.leaf("ms", ".1f"),
                tt.leaf("reco", ".2%"),
                tt.leaf("breco", ".2%"),
                tt.leaf("b_nsdr", ".2f"),
                # tt.leaf("b_nsdr_drums", ".2f"),
                # tt.leaf("b_nsdr_bass", ".2f"),
                # tt.leaf("b_nsdr_other", ".2f"),
                # tt.leaf("b_nsdr_vocals", ".2f"),
            ], align=">"),
            tt.group("test", [
                tt.leaf(name, ".2f")
                for name in self.test_metrics
            ], align=">")
        ]

    def process_history(self, history):
        """Aggregate the per-epoch history into the train/valid/test fields shown in the table."""
        train = {
            'epoch': len(history),
        }
        valid = {}
        test = {}
        best_v_main = float('inf')
        breco = float('inf')
        for metrics in history:
            train.update(metrics['train'])
            valid.update(metrics['valid'])
            if 'main' in metrics['valid']:
                best_v_main = min(best_v_main, metrics['valid']['main']['loss'])
                valid['bmain'] = best_v_main
            valid['breco'] = min(breco, metrics['valid']['reco'])
            breco = valid['breco']
            if (metrics['valid']['loss'] == metrics['valid']['best'] or
                    metrics['valid'].get('nsdr') == metrics['valid']['best']):
                for k, v in metrics['valid'].items():
                    if k.startswith('reco_'):
                        valid['b_' + k[len('reco_'):]] = v
                    if k.startswith('nsdr'):
                        valid[f'b_{k}'] = v
            if 'test' in metrics:
                test.update(metrics['test'])
        metrics = history[-1]
        return {"train": train, "valid": valid, "test": test}
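
For reference, a minimal sketch of the per-epoch `history` entries that `process_history` consumes; the keys mirror the ones read above, while the numeric values (and the exact key set of a real run) are made-up placeholders.

# Hypothetical two-epoch history; values are placeholders, not real results.
history = [
    {
        "train": {"epoch": 1, "reco": 0.123},
        "valid": {"loss": 0.30, "best": 0.30, "reco": 0.045,
                  "nsdr": 6.1, "main": {"loss": 0.30}},
    },
    {
        "train": {"epoch": 2, "reco": 0.110},
        "valid": {"loss": 0.28, "best": 0.28, "reco": 0.041,
                  "nsdr": 6.4, "main": {"loss": 0.28}},
        "test": {"nsdr": 6.5, "sdr_med": 6.2},
    },
]
# process_history would report breco = 0.041 (lowest valid reco seen) and copy
# the nsdr of the epoch whose loss equals `best` into the b_nsdr column.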

demucs/grids/mdx.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Main training for the Track A MDX models.
"""

from ._explorers import MyExplorer
from ..train import main


TRACK_A = ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68']


@MyExplorer
def explorer(launcher):
    launcher.slurm_(
        gpus=8,
        time=3 * 24 * 60,
        partition='learnlab')

    # Reproduce results from MDX competition Track A
    # This trains the first round of models. Once this is trained,
    # you will need to schedule `mdx_refine`.
    for sig in TRACK_A:
        xp = main.get_xp_from_sig(sig)
        parent = xp.cfg.continue_from
        xp = main.get_xp_from_sig(parent)
        launcher(xp.argv)
        launcher(xp.argv, {'quant.diffq': 1e-4})
        launcher(xp.argv, {'quant.diffq': 3e-4})

demucs/grids/mdx_extra.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Main training for the Track B MDX models.
"""

from ._explorers import MyExplorer
from ..train import main

TRACK_B = ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08']


@MyExplorer
def explorer(launcher):
    launcher.slurm_(
        gpus=8,
        time=3 * 24 * 60,
        partition='learnlab')

    # Reproduce results from MDX competition Track B
    # This trains the first round of models. Once this is trained,
    # you will need to schedule `mdx_refine`.
    for sig in TRACK_B:
        while sig is not None:
            xp = main.get_xp_from_sig(sig)
            sig = xp.cfg.continue_from

        for dset in ['extra44', 'extra_test']:
            sub = launcher.bind(xp.argv, dset=dset)
            sub()
            if dset == 'extra_test':
                sub({'quant.diffq': 1e-4})
                sub({'quant.diffq': 3e-4})

demucs/grids/mdx_refine.py (new file, 34 lines)
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Refinement (second round of training) for the Track A MDX models.
"""

from ._explorers import MyExplorer
from .mdx import TRACK_A
from ..train import main


@MyExplorer
def explorer(launcher):
    launcher.slurm_(
        gpus=8,
        time=3 * 24 * 60,
        partition='learnlab')

    # Reproduce results from MDX competition Track A
    # WARNING: all the experiments in the `mdx` grid must have completed.
    for sig in TRACK_A:
        xp = main.get_xp_from_sig(sig)
        launcher(xp.argv)
        for diffq in [1e-4, 3e-4]:
            xp_src = main.get_xp_from_sig(xp.cfg.continue_from)
            q_argv = [f'quant.diffq={diffq}']
            actual_src = main.get_xp(xp_src.argv + q_argv)
            actual_src.link.load()
            assert len(actual_src.link.history) == actual_src.cfg.epochs
            argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"']
            launcher(argv)
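
The last two lines of the loop above only assemble command-line style overrides; here is a standalone sketch with placeholder values (the signature and options below are hypothetical, not real experiment settings).

# Placeholder values only, to show how the refine argv is assembled.
xp_argv = ['model=hdemucs', 'dset=extra44']       # hypothetical base overrides
q_argv = ['quant.diffq=0.0001']                   # DiffQ penalty override
parent_sig = 'abcd1234'                           # hypothetical parent signature
argv = xp_argv + q_argv + [f'continue_from="{parent_sig}"']
print(argv)
# ['model=hdemucs', 'dset=extra44', 'quant.diffq=0.0001', 'continue_from="abcd1234"']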

demucs/grids/mmi.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from ._explorers import MyExplorer
from dora import Launcher


@MyExplorer
def explorer(launcher: Launcher):
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="devlab,learnlab,learnfair")  # 3 days

    sub = launcher.bind_(
        {
            "dset": "extra_mmi_goodclean",
            "test.shifts": 0,
            "model": "htdemucs",
            "htdemucs.dconv_mode": 3,
            "htdemucs.depth": 4,
            "htdemucs.t_dropout": 0.02,
            "htdemucs.t_layers": 5,
            "max_batches": 800,
            "ema.epoch": [0.9, 0.95],
            "ema.batch": [0.9995, 0.9999],
            "dset.segment": 10,
            "batch_size": 32,
        }
    )
    sub({"model": "hdemucs"})
    sub({"model": "hdemucs", "dset": "extra44"})
    sub({"model": "hdemucs", "dset": "musdb44"})

    sparse = {
        'batch_size': 3 * 8,
        'augment.remix.group_size': 3,
        'htdemucs.t_auto_sparsity': True,
        'htdemucs.t_sparse_self_attn': True,
        'htdemucs.t_sparse_cross_attn': True,
        'htdemucs.t_sparsity': 0.9,
        "htdemucs.t_layers": 7
    }

    with launcher.job_array():
        for transf_layers in [5, 7]:
            for bottom_channels in [0, 512]:
                sub = launcher.bind({
                    "htdemucs.t_layers": transf_layers,
                    "htdemucs.bottom_channels": bottom_channels,
                })
                if bottom_channels == 0 and transf_layers == 5:
                    sub({"augment.remix.proba": 0.0})
                    sub({
                        "augment.repitch.proba": 0.0,
                        # when doing repitching, we trim the output to align on the
                        # highest change of BPM. When removing repitching,
                        # we simulate it here to ensure the training context is the same.
                        # Another second is lost for all experiments due to the random
                        # shift augmentation.
                        "dset.segment": 10 * 0.88})
                elif bottom_channels == 512 and transf_layers == 5:
                    sub(dset="musdb44")
                    sub(dset="extra44")
                    # Sparse kernel XP, currently not released as kernels are still experimental.
                    sub(sparse, {'dset.segment': 15, "htdemucs.t_layers": 7})

                for duration in [5, 10, 15]:
                    sub({"dset.segment": duration})
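
The grid above is built entirely from `bind`/`bind_` composition; below is a toy stand-in, not dora's actual Launcher API, that only records override dicts, to show how the nesting expands into concrete runs.

# Toy stand-in for illustration only; the real dora Launcher also handles
# Slurm config, job arrays and scheduling.
class ToyLauncher:
    def __init__(self, overrides=None, scheduled=None):
        self.overrides = dict(overrides or {})
        self.scheduled = scheduled if scheduled is not None else []

    def bind(self, extra):
        # like launcher.bind: a copy carrying the extra overrides
        return ToyLauncher({**self.overrides, **extra}, self.scheduled)

    def bind_(self, extra):
        # like launcher.bind_: update this launcher in place
        self.overrides.update(extra)

    def __call__(self, extra=None):
        # like launcher(...): record one concrete run with merged overrides
        self.scheduled.append({**self.overrides, **(extra or {})})


launcher = ToyLauncher({"model": "htdemucs", "dset.segment": 10})
sub = launcher.bind({"htdemucs.t_layers": 5, "htdemucs.bottom_channels": 0})
sub({"augment.remix.proba": 0.0})
sub({"dset.segment": 10 * 0.88})
print(launcher.scheduled)  # two runs, each carrying the merged overrides

The `10 * 0.88` segment mirrors the comment in the grid: when repitching is disabled, the segment is shortened to 88% so the effective training context matches the repitched runs.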

demucs/grids/mmi_ft.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from ._explorers import MyExplorer
from dora import Launcher
from demucs import train


def get_sub(launcher, sig):
    xp = train.main.get_xp_from_sig(sig)
    sub = launcher.bind(xp.argv)
    sub()
    sub.bind_({
        'continue_from': sig,
        'continue_best': True})
    return sub


@MyExplorer
def explorer(launcher: Launcher):
    launcher.slurm_(gpus=4, time=3 * 24 * 60, partition="devlab,learnlab,learnfair")  # 3 days
    ft = {
        'optim.lr': 1e-4,
        'augment.remix.proba': 0,
        'augment.scale.proba': 0,
        'augment.shift_same': True,
        'htdemucs.t_weight_decay': 0.05,
        'batch_size': 8,
        'optim.clip_grad': 5,
        'optim.optim': 'adamw',
        'epochs': 50,
        'dset.wav2_valid': True,
        'ema.epoch': [],  # let's make valid a bit faster
    }
    with launcher.job_array():
        for sig in ['2899e11a']:
            sub = get_sub(launcher, sig)
            sub.bind_(ft)
            for segment in [15, 18]:
                for source in range(4):
                    w = [0] * 4
                    w[source] = 1
                    sub({'weights': w, 'dset.segment': segment})

        for sig in ['955717e8']:
            sub = get_sub(launcher, sig)
            sub.bind_(ft)
            for segment in [10, 15]:
                for source in range(4):
                    w = [0] * 4
                    w[source] = 1
                    sub({'weights': w, 'dset.segment': segment})
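
The innermost loops above build one-hot `weights` vectors, so each fine-tuning run weights the loss toward a single source. Here is a small sketch of the overrides generated for one parent signature (the drums/bass/other/vocals labels assume the usual Demucs source ordering and are for illustration only).

# Enumerate the per-source overrides produced for one parent signature.
SOURCES = ["drums", "bass", "other", "vocals"]  # assumed ordering, labels only
for segment in [15, 18]:
    for source in range(4):
        w = [0] * 4
        w[source] = 1
        print(SOURCES[source], {"weights": w, "dset.segment": segment})
# 8 runs per signature: 4 one-hot weight vectors x 2 segment lengths.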

demucs/grids/repro.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Easier training for reproducibility
"""

from ._explorers import MyExplorer


@MyExplorer
def explorer(launcher):
    launcher.slurm_(
        gpus=8,
        time=3 * 24 * 60,
        partition='devlab,learnlab')

    launcher.bind_({'ema.epoch': [0.9, 0.95]})
    launcher.bind_({'ema.batch': [0.9995, 0.9999]})
    launcher.bind_({'epochs': 600})

    base = {'model': 'demucs', 'demucs.dconv_mode': 0, 'demucs.gelu': False,
            'demucs.lstm_layers': 2}
    newt = {'model': 'demucs', 'demucs.normalize': True}
    hdem = {'model': 'hdemucs'}
    svd = {'svd.penalty': 1e-5, 'svd': 'base2'}

    with launcher.job_array():
        for model in [base, newt, hdem]:
            sub = launcher.bind(model)
            if model is base:
                # Training the v2 Demucs on MusDB HQ
                sub(epochs=360)
                continue

            # those two will be used in the repro_mdx_a bag of models.
            sub(svd)
            sub(svd, seed=43)
            if model == newt:
                # Ablation study
                sub()
                abl = sub.bind(svd)
                abl({'ema.epoch': [], 'ema.batch': []})
                abl({'demucs.dconv_lstm': 10})
                abl({'demucs.dconv_attn': 10})
                abl({'demucs.dconv_attn': 10, 'demucs.dconv_lstm': 10, 'demucs.lstm_layers': 2})
                abl({'demucs.dconv_mode': 0})
                abl({'demucs.gelu': False})

demucs/grids/repro_ft.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Fine tuning experiments
"""

from ._explorers import MyExplorer
from ..train import main


@MyExplorer
def explorer(launcher):
    launcher.slurm_(
        gpus=8,
        time=300,
        partition='devlab,learnlab')

    # Mus
    launcher.slurm_(constraint='volta32gb')

    grid = "repro"
    folder = main.dora.dir / "grids" / grid

    for sig in folder.iterdir():
        if not sig.is_symlink():
            continue
        xp = main.get_xp_from_sig(sig)
        xp.link.load()
        if len(xp.link.history) != xp.cfg.epochs:
            continue
        sub = launcher.bind(xp.argv, [f'continue_from="{xp.sig}"'])
        sub.bind_({'ema.epoch': [0.9, 0.95], 'ema.batch': [0.9995, 0.9999]})
        sub.bind_({'test.every': 1, 'test.sdr': True, 'epochs': 4})
        sub.bind_({'dset.segment': 28, 'dset.shift': 2})
        sub.bind_({'batch_size': 32})
        auto = {'dset': 'auto_mus'}
        auto.update({'augment.remix.proba': 0, 'augment.scale.proba': 0,
                     'augment.shift_same': True})
        sub.bind_(auto)
        sub.bind_({'batch_size': 16})
        sub.bind_({'optim.lr': 1e-4})
        sub.bind_({'model_segment': 44})
        sub()

demucs/grids/sdx23.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from ._explorers import MyExplorer
from dora import Launcher


@MyExplorer
def explorer(launcher: Launcher):
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="speechgpt,learnfair",
                    mem_per_gpu=None, constraint='')
    launcher.bind_({"dset.use_musdb": False})

    with launcher.job_array():
        launcher(dset='sdx23_bleeding')
        launcher(dset='sdx23_labelnoise')