generated from thinkode/modelRepository
Initial commit and v1.0
This commit is contained in:
86
demucs/repitch.py
Normal file
86
demucs/repitch.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""Utility for on the fly pitch/tempo change for data augmentation."""
|
||||
|
||||
import random
|
||||
import subprocess as sp
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
import torchaudio as ta
|
||||
|
||||
from .audio import save_audio
|
||||
|
||||
|
||||
class RepitchedWrapper:
|
||||
"""
|
||||
Wrap a dataset to apply online change of pitch / tempo.
|
||||
"""
|
||||
def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12,
|
||||
tempo_std=5, vocals=[3], same=True):
|
||||
self.dataset = dataset
|
||||
self.proba = proba
|
||||
self.max_pitch = max_pitch
|
||||
self.max_tempo = max_tempo
|
||||
self.tempo_std = tempo_std
|
||||
self.same = same
|
||||
self.vocals = vocals
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
streams = self.dataset[index]
|
||||
in_length = streams.shape[-1]
|
||||
out_length = int((1 - 0.01 * self.max_tempo) * in_length)
|
||||
|
||||
if random.random() < self.proba:
|
||||
outs = []
|
||||
for idx, stream in enumerate(streams):
|
||||
if idx == 0 or not self.same:
|
||||
delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
|
||||
delta_tempo = random.gauss(0, self.tempo_std)
|
||||
delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
|
||||
stream = repitch(
|
||||
stream,
|
||||
delta_pitch,
|
||||
delta_tempo,
|
||||
voice=idx in self.vocals)
|
||||
outs.append(stream[:, :out_length])
|
||||
streams = torch.stack(outs)
|
||||
else:
|
||||
streams = streams[..., :out_length]
|
||||
return streams
|
||||
|
||||
|
||||
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
|
||||
"""
|
||||
tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
|
||||
pitch is in semi tones.
|
||||
Requires `soundstretch` to be installed, see
|
||||
https://www.surina.net/soundtouch/soundstretch.html
|
||||
"""
|
||||
infile = tempfile.NamedTemporaryFile(suffix=".wav")
|
||||
outfile = tempfile.NamedTemporaryFile(suffix=".wav")
|
||||
save_audio(wav, infile.name, samplerate, clip='clamp')
|
||||
command = [
|
||||
"soundstretch",
|
||||
infile.name,
|
||||
outfile.name,
|
||||
f"-pitch={pitch}",
|
||||
f"-tempo={tempo:.6f}",
|
||||
]
|
||||
if quick:
|
||||
command += ["-quick"]
|
||||
if voice:
|
||||
command += ["-speech"]
|
||||
try:
|
||||
sp.run(command, capture_output=True, check=True)
|
||||
except sp.CalledProcessError as error:
|
||||
raise RuntimeError(f"Could not change bpm because {error.stderr.decode('utf-8')}")
|
||||
wav, sr = ta.load(outfile.name)
|
||||
assert sr == samplerate
|
||||
return wav
|
||||
Reference in New Issue
Block a user