Initial commit and v1.0

demucs/evaluate.py (new file, 174 lines)
@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Test time evaluation, either using the original SDR from [Vincent et al. 2006]
or the newest SDR definition from the MDX 2021 competition (this one will
be reported as `nsdr` for `new sdr`).
"""
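
# Editor's note: both metrics are signal-to-distortion ratios of the form
# 10 * log10(||reference||^2 / ||reference - estimate||^2). `new_sdr` below
# computes a single global ratio per track, while museval's bss_eval reports
# framewise values (and, unlike `new_sdr`, involves BSSEval's distortion-filter
# projection).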

from concurrent import futures
import logging

from dora.log import LogProgress
import numpy as np
import musdb
import museval
import torch as th

from .apply import apply_model
from .audio import convert_audio, save_audio
from . import distrib
from .utils import DummyPoolExecutor


logger = logging.getLogger(__name__)


def new_sdr(references, estimates):
    """
    Compute the SDR according to the MDX challenge definition.
    Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)
    """
    assert references.dim() == 4
    assert estimates.dim() == 4
    delta = 1e-7  # avoid numerical errors
    num = th.sum(th.square(references), dim=(2, 3))
    den = th.sum(th.square(references - estimates), dim=(2, 3))
    num += delta
    den += delta
    scores = 10 * th.log10(num / den)
    return scores
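

# Shape convention, inferred from the asserts and the reduction dims above:
# inputs are (batch, sources, channels, time) and the returned scores are
# (batch, sources), in dB. For a perfect estimate the denominator collapses to
# delta, so a unit-energy reference scores 10 * log10((1 + 1e-7) / 1e-7),
# roughly 70 dB.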


def eval_track(references, estimates, win, hop, compute_sdr=True):
    references = references.transpose(1, 2).double()
    estimates = estimates.transpose(1, 2).double()

    new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0]

    if not compute_sdr:
        return None, new_scores
    else:
        references = references.numpy()
        estimates = estimates.numpy()
        scores = museval.metrics.bss_eval(
            references, estimates,
            compute_permutation=False,
            window=win,
            hop=hop,
            framewise_filters=False,
            bsseval_sources_version=False)[:-1]
        return scores, new_scores
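

# Return contract: `scores` is the museval tuple (sdr, isr, sir, sar) of
# framewise arrays, or None when compute_sdr is False; `new_scores` holds one
# global nsdr value per source. The trailing [:-1] drops bss_eval's final
# permutation output, which is unused here since compute_permutation=False.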


def evaluate(solver, compute_sdr=False):
    """
    Evaluate model using museval.
    compute_sdr=False means using only the MDX definition of the SDR, which
    is much faster to evaluate.
    """

    args = solver.args

    output_dir = solver.folder / "results"
    output_dir.mkdir(exist_ok=True, parents=True)
    json_folder = solver.folder / "results/test"
    json_folder.mkdir(exist_ok=True, parents=True)

    # we load tracks from the original musdb set
    if args.test.nonhq is None:
        test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True)
    else:
        test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False)
    src_rate = args.dset.musdb_samplerate

    eval_device = 'cpu'

    model = solver.model
    win = int(1. * model.samplerate)
    hop = int(1. * model.samplerate)
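    # One-second, non-overlapping analysis frames at the model samplerate; this
    # mirrors museval's usual 1 s window/hop convention for framewise BSSEval.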

    indexes = range(distrib.rank, len(test_set), distrib.world_size)
    indexes = LogProgress(logger, indexes, updates=args.misc.num_prints,
                          name='Eval')
    pendings = []

    pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor
    with pool(args.test.workers) as pool:
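        # With args.test.workers == 0 this falls back to DummyPoolExecutor,
        # which (as its name and use here suggest) mimics the executor API but
        # runs each submitted call synchronously in the current process.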
        for index in indexes:
            track = test_set.tracks[index]

            mix = th.from_numpy(track.audio).t().float()
            if mix.dim() == 1:
                mix = mix[None]
            mix = mix.to(solver.device)
            ref = mix.mean(dim=0)  # mono mixture
            mix = (mix - ref.mean()) / ref.std()
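            # The mixture is standardized using the mono reference's mean/std;
            # the inverse transform is applied to the estimates a few lines
            # below so metrics are computed at the original scale.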
            mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels)
            estimates = apply_model(model, mix[None],
                                    shifts=args.test.shifts, split=args.test.split,
                                    overlap=args.test.overlap)[0]
            estimates = estimates * ref.std() + ref.mean()
            estimates = estimates.to(eval_device)

            references = th.stack(
                [th.from_numpy(track.targets[name].audio).t() for name in model.sources])
            if references.dim() == 2:
                references = references[:, None]
            references = references.to(eval_device)
            references = convert_audio(references, src_rate,
                                       model.samplerate, model.audio_channels)
            if args.test.save:
                folder = solver.folder / "wav" / track.name
                folder.mkdir(exist_ok=True, parents=True)
                for name, estimate in zip(model.sources, estimates):
                    save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate)

            pendings.append((track.name, pool.submit(
                eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr)))
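
        # Separation (on solver.device) and metric computation (CPU workers)
        # are decoupled: futures are merely collected in the loop above and
        # only resolved in the loop below.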

        pendings = LogProgress(logger, pendings, updates=args.misc.num_prints,
                               name='Eval (BSS)')
        tracks = {}
        for track_name, pending in pendings:
            pending = pending.result()
            scores, nsdrs = pending
            tracks[track_name] = {}
            for idx, target in enumerate(model.sources):
                tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]}
            if scores is not None:
                (sdr, isr, sir, sar) = scores
                for idx, target in enumerate(model.sources):
                    values = {
                        "SDR": sdr[idx].tolist(),
                        "SIR": sir[idx].tolist(),
                        "ISR": isr[idx].tolist(),
                        "SAR": sar[idx].tolist()
                    }
                    tracks[track_name][target].update(values)

        all_tracks = {}
        for src in range(distrib.world_size):
            all_tracks.update(distrib.share(tracks, src))
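        # distrib.share presumably broadcasts rank `src`'s partial `tracks`
        # dict to all processes, so every rank ends up with results for the
        # full test set.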

        result = {}
        metric_names = next(iter(all_tracks.values()))[model.sources[0]]
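        # `metric_names` is the per-source metrics dict of the first track;
        # iterating it yields the metric keys: 'nsdr', plus 'SDR', 'SIR',
        # 'ISR' and 'SAR' when compute_sdr was enabled. The loop below yields
        # keys such as 'sdr_vocals' (mean over tracks of per-track medians),
        # 'sdr_med_vocals' (median of medians), and 'sdr' / 'sdr_med' averaged
        # across sources.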
        for metric_name in metric_names:
            avg = 0
            avg_of_medians = 0
            for source in model.sources:
                medians = [
                    np.nanmedian(all_tracks[track][source][metric_name])
                    for track in all_tracks.keys()]
                mean = np.mean(medians)
                median = np.median(medians)
                result[metric_name.lower() + "_" + source] = mean
                result[metric_name.lower() + "_med" + "_" + source] = median
                avg += mean / len(model.sources)
                avg_of_medians += median / len(model.sources)
            result[metric_name.lower()] = avg
            result[metric_name.lower() + "_med"] = avg_of_medians
        return result