generated from thinkode/modelRepository
87 lines
2.7 KiB
Python
87 lines
2.7 KiB
Python
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||
|
|
# All rights reserved.
|
||
|
|
#
|
||
|
|
# This source code is licensed under the license found in the
|
||
|
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
"""Utility for on the fly pitch/tempo change for data augmentation."""
|
||
|
|
|
||
|
|
import random
|
||
|
|
import subprocess as sp
|
||
|
|
import tempfile
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torchaudio as ta
|
||
|
|
|
||
|
|
from .audio import save_audio
|
||
|
|
|
||
|
|
|
||
|
|
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    With probability `proba`, the sources of an item returned by the wrapped
    dataset are pitch-shifted and time-stretched through `repitch`. Whether or
    not the augmentation is applied, the item is cropped along its last axis to
    `int((1 - 0.01 * max_tempo) * length)` samples. Every returned item thus has
    the same length regardless of the random tempo change drawn for it.

    Args:
        dataset: indexable object supporting `len()`. `dataset[index]` must
            return a tensor whose first axis enumerates sources and whose last
            axis is time. Each source is sliced as `stream[:, :n]`, so it is
            presumably (channels, time) per source; TODO confirm against callers.
        proba: probability of applying the augmentation to a given item.
        max_pitch: maximum absolute pitch shift in semitones. The shift is drawn
            uniformly from the integers in [-max_pitch, max_pitch].
        max_tempo: maximum absolute tempo change in percent. The Gaussian draw
            is clamped to [-max_tempo, max_tempo].
        tempo_std: standard deviation of the Gaussian tempo change, in percent.
        vocals: indices of the sources passed to `repitch` with `voice=True`.
            `None` (the default) means `[3]`.
        same: if True, a single pitch/tempo draw is shared by all sources of an
            item; otherwise every source gets its own draw.
    """

    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12,
                 tempo_std=5, vocals=None, same=True):
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        self.same = same
        # Avoid a mutable default argument: a literal `[3]` default would be a
        # single list object aliased by every instance created without `vocals`.
        self.vocals = [3] if vocals is None else vocals

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # A tempo change of +t percent shortens the audio to roughly
        # in_length / (1 + t / 100) >= (1 - t / 100) * in_length samples.
        # Cropping every item to the worst case (t = max_tempo) gives a fixed
        # output length whether or not the item was augmented.
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)

        if random.random() >= self.proba:
            # No augmentation for this item: just crop to the common length.
            return streams[..., :out_length]

        outs = []
        for idx, stream in enumerate(streams):
            # Draw new deltas for the first source, and for every other source
            # only when sources are not meant to share the same shift.
            if idx == 0 or not self.same:
                delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
                delta_tempo = random.gauss(0, self.tempo_std)
                delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
            stream = repitch(
                stream,
                delta_pitch,
                delta_tempo,
                voice=idx in self.vocals)
            outs.append(stream[:, :out_length])
        return torch.stack(outs)
|
||
|
|
|
||
|
|
|
||
|
|
def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100):
    """
    Change the pitch and tempo of `wav` with the external `soundstretch` tool.

    tempo is a relative delta in percentage, so tempo=10 means tempo at 110%!
    pitch is in semi tones.
    Requires `soundstretch` to be installed, see
    https://www.surina.net/soundtouch/soundstretch.html

    Args:
        wav: audio written to a temporary WAV file with `save_audio` (clipping
            mode 'clamp'). Presumably a (channels, time) tensor; TODO confirm.
        pitch: pitch shift in semitones, passed as `-pitch=<pitch>`.
        tempo: tempo change in percent, passed as `-tempo=<tempo>` with six
            decimal places.
        voice: if True, pass `-speech` to soundstretch.
        quick: if True, pass `-quick` to soundstretch.
        samplerate: sample rate used when writing the input file. The processed
            file is required to come back at the same rate.

    Returns:
        The processed audio as returned by `torchaudio.load`.

    Raises:
        RuntimeError: if soundstretch exits with a non-zero status, or if the
            processed file's sample rate differs from `samplerate`.
        FileNotFoundError: raised by `subprocess.run` when `soundstretch` is not
            found on the PATH.
    """
    import os  # local import keeps this block self-contained

    # Use a private temporary directory rather than open NamedTemporaryFile
    # handles. The external process must reopen the files by name, which is
    # not possible on Windows while a NamedTemporaryFile is still open. The
    # directory is also removed deterministically, even when soundstretch fails.
    with tempfile.TemporaryDirectory() as tmpdir:
        in_path = os.path.join(tmpdir, "input.wav")
        out_path = os.path.join(tmpdir, "output.wav")
        save_audio(wav, in_path, samplerate, clip='clamp')

        command = [
            "soundstretch",
            in_path,
            out_path,
            f"-pitch={pitch}",
            f"-tempo={tempo:.6f}",
        ]
        if quick:
            command.append("-quick")
        if voice:
            command.append("-speech")

        try:
            sp.run(command, capture_output=True, check=True)
        except sp.CalledProcessError as error:
            # errors='replace' so an undecodable message cannot hide the failure.
            stderr = error.stderr.decode('utf-8', errors='replace')
            raise RuntimeError(f"Could not change bpm because {stderr}") from error

        # Load while the temporary directory still exists.
        out_wav, sr = ta.load(out_path)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if sr != samplerate:
        raise RuntimeError(
            f"soundstretch output has sample rate {sr}, expected {samplerate}")
    return out_wav
|