Files
audioSeparator/demucs/pretrained.py

98 lines
3.2 KiB
Python
Raw Normal View History

2025-04-21 13:41:19 +02:00
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading pretrained models.
"""
import logging
from pathlib import Path
import typing as tp
from log import log_step
from dora.log import fatal, bold
from .hdemucs import HDemucs
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa
from .states import _check_diffq
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Base URL prepended to every remote model path built by `_parse_remote_files`.
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/"
# `remote` folder next to this file; `get_model` reads `files.txt` from it
# and hands it to `BagOnlyRepo` when no local repo is given.
REMOTE_ROOT = Path(__file__).parent / 'remote'
# Source names passed to the `HDemucs` model built by `demucs_unittest`.
SOURCES = ["drums", "bass", "other", "vocals"]
# Model name `get_model_from_args` falls back to when `args.name` is None.
DEFAULT_MODEL = 'htdemucs'
def demucs_unittest():
    """Return a freshly built `HDemucs` with 4 channels and the module `SOURCES`."""
    return HDemucs(channels=4, sources=SOURCES)
def add_model_flags(parser):
    """Register the model-selection command-line options on `parser`.

    Adds `-s/--sig` and `-n/--name` as mutually exclusive options (neither is
    required), plus a standalone `--repo` option parsed as a `Path`.

    Args:
        parser: an argparse-style parser; it is modified in place.
    """
    # `--sig` and `--name` are alternative ways of picking a model, so they
    # live in one exclusive group: at most one of them may be supplied.
    selector = parser.add_mutually_exclusive_group(required=False)
    selector.add_argument(
        "-s", "--sig",
        help="Locally trained XP signature.",
    )
    selector.add_argument(
        "-n", "--name",
        default=None,
        help="Pretrained model name or signature. Default is htdemucs.",
    )

    parser.add_argument(
        "--repo",
        type=Path,
        help="Folder containing all pre-trained models for use with -n.",
    )
def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
    """Parse a remote file listing into a signature -> download URL mapping.

    Each line of the listing is stripped of surrounding whitespace and then
    interpreted as follows:

    - blank lines and lines starting with ``#`` are ignored;
    - ``root: <prefix>`` sets the prefix applied to all following entries;
    - any other line is a file name: the text before its first ``-`` is the
      signature, and its URL is ``ROOT_URL + root + line``.

    Args:
        remote_file_list: object exposing ``read_text()`` returning the
            listing as a string (e.g. a ``pathlib.Path``).

    Returns:
        Dict mapping each signature to its full URL.

    Raises:
        AssertionError: if the same signature is listed more than once.
    """
    root: str = ''
    models: tp.Dict[str, str] = {}
    for raw_line in remote_file_list.read_text().split('\n'):
        line = raw_line.strip()
        # Skip comments and blank lines. Without the blank-line guard, an
        # empty line (e.g. the one produced by a trailing newline) would be
        # registered as a model with the empty signature ''.
        if not line or line.startswith('#'):
            continue
        if line.startswith('root:'):
            root = line.split(':', 1)[1].strip()
            continue
        sig = line.split('-', 1)[0]
        assert sig not in models
        models[sig] = ROOT_URL + root + line
    return models
def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """Load a model by name and return it after calling ``eval()`` on it.

    `name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is not None.
    The special name ``'demucs_unittest'`` bypasses every repo and returns
    the model built by `demucs_unittest()` as is.

    Args:
        name: bag-of-models name, pretrained signature, or ``'demucs_unittest'``.
        repo: optional local folder to look models up in. When None, the
            listing in ``REMOTE_ROOT / 'files.txt'`` is used instead.

    Returns:
        The model returned by ``AnyModelRepo.get_model``.

    Raises:
        ImportError: re-raised unchanged when loading the model fails on an
            import. If the message mentions ``diffq``, `_check_diffq()` is
            called first.
    """
    if name == 'demucs_unittest':
        return demucs_unittest()

    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            # NOTE(review): `fatal` comes from dora.log and presumably aborts
            # the process; confirm that execution never continues past it.
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)

    any_repo = AnyModelRepo(model_repo, bag_repo)
    try:
        model = any_repo.get_model(name)
    except ImportError as exc:
        # Inspect the message via str(exc) rather than exc.args[0]: the latter
        # raises IndexError when the exception carries no args (and TypeError
        # when the first arg is not a string), which would hide the original
        # import failure.
        if 'diffq' in str(exc):
            _check_diffq()
        raise
    model.eval()
    return model
def get_model_from_args(args):
    """Load local model package or pre-trained model.

    Args:
        args: namespace-like object with `name` and `repo` attributes (the
            options registered by `add_model_flags`). When `args.name` is
            None it is overwritten in place with `DEFAULT_MODEL` and a
            warning about the changed default is logged.

    Returns:
        Whatever `get_model(name=args.name, repo=args.repo)` returns.
    """
    if args.name is None:
        args.name = DEFAULT_MODEL
        # The adjacent literals are concatenated into one message; the first
        # fragment must end with ", " so it does not run straight into the next
        # one (previously rendered as "`htdemucs`.the latest ...").
        log_step(
            "warning", 100,
            "Important: the default model was recently changed to `htdemucs`, "
            "the latest Hybrid Transformer Demucs model. In some cases, this model can "
            "actually perform worse than previous models. To get back the old default model "
            "use `-n mdx_extra_q`.")
    return get_model(name=args.name, repo=args.repo)