generated from thinkode/modelRepository
Initial commit and v1.0
This commit is contained in:
254
demucs/wav.py
Normal file
254
demucs/wav.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Loading wav based datasets, including MusdbHQ."""
|
||||
|
||||
from collections import OrderedDict
|
||||
import hashlib
|
||||
import math
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tqdm
|
||||
|
||||
import musdb
|
||||
import julius
|
||||
import torch as th
|
||||
from torch import distributed
|
||||
import torchaudio as ta
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .audio import convert_audio_channels
|
||||
from . import distrib
|
||||
|
||||
MIXTURE = "mixture"
|
||||
EXT = ".wav"
|
||||
|
||||
|
||||
def _track_metadata(track, sources, normalize=True, ext=EXT):
|
||||
track_length = None
|
||||
track_samplerate = None
|
||||
mean = 0
|
||||
std = 1
|
||||
for source in sources + [MIXTURE]:
|
||||
file = track / f"{source}{ext}"
|
||||
if source == MIXTURE and not file.exists():
|
||||
audio = 0
|
||||
for sub_source in sources:
|
||||
sub_file = track / f"{sub_source}{ext}"
|
||||
sub_audio, sr = ta.load(sub_file)
|
||||
audio += sub_audio
|
||||
would_clip = audio.abs().max() >= 1
|
||||
if would_clip:
|
||||
assert ta.get_audio_backend() == 'soundfile', 'use dset.backend=soundfile'
|
||||
ta.save(file, audio, sr, encoding='PCM_F')
|
||||
|
||||
try:
|
||||
info = ta.info(str(file))
|
||||
except RuntimeError:
|
||||
print(file)
|
||||
raise
|
||||
length = info.num_frames
|
||||
if track_length is None:
|
||||
track_length = length
|
||||
track_samplerate = info.sample_rate
|
||||
elif track_length != length:
|
||||
raise ValueError(
|
||||
f"Invalid length for file {file}: "
|
||||
f"expecting {track_length} but got {length}.")
|
||||
elif info.sample_rate != track_samplerate:
|
||||
raise ValueError(
|
||||
f"Invalid sample rate for file {file}: "
|
||||
f"expecting {track_samplerate} but got {info.sample_rate}.")
|
||||
if source == MIXTURE and normalize:
|
||||
try:
|
||||
wav, _ = ta.load(str(file))
|
||||
except RuntimeError:
|
||||
print(file)
|
||||
raise
|
||||
wav = wav.mean(0)
|
||||
mean = wav.mean().item()
|
||||
std = wav.std().item()
|
||||
|
||||
return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}
|
||||
|
||||
|
||||
def build_metadata(path, sources, normalize=True, ext=EXT):
|
||||
"""
|
||||
Build the metadata for `Wavset`.
|
||||
|
||||
Args:
|
||||
path (str or Path): path to dataset.
|
||||
sources (list[str]): list of sources to look for.
|
||||
normalize (bool): if True, loads full track and store normalization
|
||||
values based on the mixture file.
|
||||
ext (str): extension of audio files (default is .wav).
|
||||
"""
|
||||
|
||||
meta = {}
|
||||
path = Path(path)
|
||||
pendings = []
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(8) as pool:
|
||||
for root, folders, files in os.walk(path, followlinks=True):
|
||||
root = Path(root)
|
||||
if root.name.startswith('.') or folders or root == path:
|
||||
continue
|
||||
name = str(root.relative_to(path))
|
||||
pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext)))
|
||||
# meta[name] = _track_metadata(root, sources, normalize, ext)
|
||||
for name, pending in tqdm.tqdm(pendings, ncols=120):
|
||||
meta[name] = pending.result()
|
||||
return meta
|
||||
|
||||
|
||||
class Wavset:
|
||||
def __init__(
|
||||
self,
|
||||
root, metadata, sources,
|
||||
segment=None, shift=None, normalize=True,
|
||||
samplerate=44100, channels=2, ext=EXT):
|
||||
"""
|
||||
Waveset (or mp3 set for that matter). Can be used to train
|
||||
with arbitrary sources. Each track should be one folder inside of `path`.
|
||||
The folder should contain files named `{source}.{ext}`.
|
||||
|
||||
Args:
|
||||
root (Path or str): root folder for the dataset.
|
||||
metadata (dict): output from `build_metadata`.
|
||||
sources (list[str]): list of source names.
|
||||
segment (None or float): segment length in seconds. If `None`, returns entire tracks.
|
||||
shift (None or float): stride in seconds bewteen samples.
|
||||
normalize (bool): normalizes input audio, **based on the metadata content**,
|
||||
i.e. the entire track is normalized, not individual extracts.
|
||||
samplerate (int): target sample rate. if the file sample rate
|
||||
is different, it will be resampled on the fly.
|
||||
channels (int): target nb of channels. if different, will be
|
||||
changed onthe fly.
|
||||
ext (str): extension for audio files (default is .wav).
|
||||
|
||||
samplerate and channels are converted on the fly.
|
||||
"""
|
||||
self.root = Path(root)
|
||||
self.metadata = OrderedDict(metadata)
|
||||
self.segment = segment
|
||||
self.shift = shift or segment
|
||||
self.normalize = normalize
|
||||
self.sources = sources
|
||||
self.channels = channels
|
||||
self.samplerate = samplerate
|
||||
self.ext = ext
|
||||
self.num_examples = []
|
||||
for name, meta in self.metadata.items():
|
||||
track_duration = meta['length'] / meta['samplerate']
|
||||
if segment is None or track_duration < segment:
|
||||
examples = 1
|
||||
else:
|
||||
examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1)
|
||||
self.num_examples.append(examples)
|
||||
|
||||
def __len__(self):
|
||||
return sum(self.num_examples)
|
||||
|
||||
def get_file(self, name, source):
|
||||
return self.root / name / f"{source}{self.ext}"
|
||||
|
||||
def __getitem__(self, index):
|
||||
for name, examples in zip(self.metadata, self.num_examples):
|
||||
if index >= examples:
|
||||
index -= examples
|
||||
continue
|
||||
meta = self.metadata[name]
|
||||
num_frames = -1
|
||||
offset = 0
|
||||
if self.segment is not None:
|
||||
offset = int(meta['samplerate'] * self.shift * index)
|
||||
num_frames = int(math.ceil(meta['samplerate'] * self.segment))
|
||||
wavs = []
|
||||
for source in self.sources:
|
||||
file = self.get_file(name, source)
|
||||
wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
|
||||
wav = convert_audio_channels(wav, self.channels)
|
||||
wavs.append(wav)
|
||||
|
||||
example = th.stack(wavs)
|
||||
example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
|
||||
if self.normalize:
|
||||
example = (example - meta['mean']) / meta['std']
|
||||
if self.segment:
|
||||
length = int(self.segment * self.samplerate)
|
||||
example = example[..., :length]
|
||||
example = F.pad(example, (0, length - example.shape[-1]))
|
||||
return example
|
||||
|
||||
|
||||
def get_wav_datasets(args, name='wav'):
|
||||
"""Extract the wav datasets from the XP arguments."""
|
||||
path = getattr(args, name)
|
||||
sig = hashlib.sha1(str(path).encode()).hexdigest()[:8]
|
||||
metadata_file = Path(args.metadata) / ('wav_' + sig + ".json")
|
||||
train_path = Path(path) / "train"
|
||||
valid_path = Path(path) / "valid"
|
||||
if not metadata_file.is_file() and distrib.rank == 0:
|
||||
metadata_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
train = build_metadata(train_path, args.sources)
|
||||
valid = build_metadata(valid_path, args.sources)
|
||||
json.dump([train, valid], open(metadata_file, "w"))
|
||||
if distrib.world_size > 1:
|
||||
distributed.barrier()
|
||||
train, valid = json.load(open(metadata_file))
|
||||
if args.full_cv:
|
||||
kw_cv = {}
|
||||
else:
|
||||
kw_cv = {'segment': args.segment, 'shift': args.shift}
|
||||
train_set = Wavset(train_path, train, args.sources,
|
||||
segment=args.segment, shift=args.shift,
|
||||
samplerate=args.samplerate, channels=args.channels,
|
||||
normalize=args.normalize)
|
||||
valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources),
|
||||
samplerate=args.samplerate, channels=args.channels,
|
||||
normalize=args.normalize, **kw_cv)
|
||||
return train_set, valid_set
|
||||
|
||||
|
||||
def _get_musdb_valid():
|
||||
# Return musdb valid set.
|
||||
import yaml
|
||||
setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml'
|
||||
setup = yaml.safe_load(open(setup_path, 'r'))
|
||||
return setup['validation_tracks']
|
||||
|
||||
|
||||
def get_musdb_wav_datasets(args):
|
||||
"""Extract the musdb dataset from the XP arguments."""
|
||||
sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8]
|
||||
metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json")
|
||||
root = Path(args.musdb) / "train"
|
||||
if not metadata_file.is_file() and distrib.rank == 0:
|
||||
metadata_file.parent.mkdir(exist_ok=True, parents=True)
|
||||
metadata = build_metadata(root, args.sources)
|
||||
json.dump(metadata, open(metadata_file, "w"))
|
||||
if distrib.world_size > 1:
|
||||
distributed.barrier()
|
||||
metadata = json.load(open(metadata_file))
|
||||
|
||||
valid_tracks = _get_musdb_valid()
|
||||
if args.train_valid:
|
||||
metadata_train = metadata
|
||||
else:
|
||||
metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks}
|
||||
metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks}
|
||||
if args.full_cv:
|
||||
kw_cv = {}
|
||||
else:
|
||||
kw_cv = {'segment': args.segment, 'shift': args.shift}
|
||||
train_set = Wavset(root, metadata_train, args.sources,
|
||||
segment=args.segment, shift=args.shift,
|
||||
samplerate=args.samplerate, channels=args.channels,
|
||||
normalize=args.normalize)
|
||||
valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources),
|
||||
samplerate=args.samplerate, channels=args.channels,
|
||||
normalize=args.normalize, **kw_cv)
|
||||
return train_set, valid_set
|
||||
Reference in New Issue
Block a user