diff --git a/.gitignore b/.gitignore index 683bdfcf0a..1316ec4391 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ rmvpe.pt # Generated by RVC /logs /weights +/dataset # To set a Python version for the project .tool-versions diff --git a/infer/lib/train/utils.py b/infer/lib/train/utils.py index 765c54c61d..4f3b7886ed 100644 --- a/infer/lib/train/utils.py +++ b/infer/lib/train/utils.py @@ -1,483 +1,478 @@ -import argparse -import glob -import json -import logging -import os -import subprocess -import sys -import shutil - -import numpy as np -import torch -from scipy.io.wavfile import read - -MATPLOTLIB_FLAG = False - -logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) -logger = logging - - -def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1): - assert os.path.isfile(checkpoint_path) - checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") - - ################## - def go(model, bkey): - saved_state_dict = checkpoint_dict[bkey] - if hasattr(model, "module"): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - new_state_dict = {} - for k, v in state_dict.items(): # 模型需要的shape - try: - new_state_dict[k] = saved_state_dict[k] - if saved_state_dict[k].shape != state_dict[k].shape: - logger.warning( - "shape-%s-mismatch. 
need: %s, get: %s", - k, - state_dict[k].shape, - saved_state_dict[k].shape, - ) # - raise KeyError - except: - # logger.info(traceback.format_exc()) - logger.info("%s is not in the checkpoint", k) # pretrain缺失的 - new_state_dict[k] = v # 模型自带的随机值 - if hasattr(model, "module"): - model.module.load_state_dict(new_state_dict, strict=False) - else: - model.load_state_dict(new_state_dict, strict=False) - return model - - go(combd, "combd") - model = go(sbd, "sbd") - ############# - logger.info("Loaded model weights") - - iteration = checkpoint_dict["iteration"] - learning_rate = checkpoint_dict["learning_rate"] - if ( - optimizer is not None and load_opt == 1 - ): ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch - # try: - optimizer.load_state_dict(checkpoint_dict["optimizer"]) - # except: - # traceback.print_exc() - logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration)) - return model, optimizer, learning_rate, iteration - - -# def load_checkpoint(checkpoint_path, model, optimizer=None): -# assert os.path.isfile(checkpoint_path) -# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') -# iteration = checkpoint_dict['iteration'] -# learning_rate = checkpoint_dict['learning_rate'] -# if optimizer is not None: -# optimizer.load_state_dict(checkpoint_dict['optimizer']) -# # print(1111) -# saved_state_dict = checkpoint_dict['model'] -# # print(1111) -# -# if hasattr(model, 'module'): -# state_dict = model.module.state_dict() -# else: -# state_dict = model.state_dict() -# new_state_dict= {} -# for k, v in state_dict.items(): -# try: -# new_state_dict[k] = saved_state_dict[k] -# except: -# logger.info("%s is not in the checkpoint" % k) -# new_state_dict[k] = v -# if hasattr(model, 'module'): -# model.module.load_state_dict(new_state_dict) -# else: -# model.load_state_dict(new_state_dict) -# logger.info("Loaded checkpoint '{}' (epoch {})" .format( -# checkpoint_path, iteration)) -# return model, optimizer, learning_rate, 
iteration -def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1): - assert os.path.isfile(checkpoint_path) - checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") - - saved_state_dict = checkpoint_dict["model"] - if hasattr(model, "module"): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - new_state_dict = {} - for k, v in state_dict.items(): # 模型需要的shape - try: - new_state_dict[k] = saved_state_dict[k] - if saved_state_dict[k].shape != state_dict[k].shape: - logger.warning( - "shape-%s-mismatch|need-%s|get-%s", - k, - state_dict[k].shape, - saved_state_dict[k].shape, - ) # - raise KeyError - except: - # logger.info(traceback.format_exc()) - logger.info("%s is not in the checkpoint", k) # pretrain缺失的 - new_state_dict[k] = v # 模型自带的随机值 - if hasattr(model, "module"): - model.module.load_state_dict(new_state_dict, strict=False) - else: - model.load_state_dict(new_state_dict, strict=False) - logger.info("Loaded model weights") - - iteration = checkpoint_dict["iteration"] - learning_rate = checkpoint_dict["learning_rate"] - if ( - optimizer is not None and load_opt == 1 - ): ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch - # try: - optimizer.load_state_dict(checkpoint_dict["optimizer"]) - # except: - # traceback.print_exc() - logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration)) - return model, optimizer, learning_rate, iteration - - -def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): - logger.info( - "Saving model and optimizer state at epoch {} to {}".format( - iteration, checkpoint_path - ) - ) - if hasattr(model, "module"): - state_dict = model.module.state_dict() - else: - state_dict = model.state_dict() - torch.save( - { - "model": state_dict, - "iteration": iteration, - "optimizer": optimizer.state_dict(), - "learning_rate": learning_rate, - }, - checkpoint_path, - ) - - -def save_checkpoint_d(combd, sbd, optimizer, 
learning_rate, iteration, checkpoint_path): - logger.info( - "Saving model and optimizer state at epoch {} to {}".format( - iteration, checkpoint_path - ) - ) - if hasattr(combd, "module"): - state_dict_combd = combd.module.state_dict() - else: - state_dict_combd = combd.state_dict() - if hasattr(sbd, "module"): - state_dict_sbd = sbd.module.state_dict() - else: - state_dict_sbd = sbd.state_dict() - torch.save( - { - "combd": state_dict_combd, - "sbd": state_dict_sbd, - "iteration": iteration, - "optimizer": optimizer.state_dict(), - "learning_rate": learning_rate, - }, - checkpoint_path, - ) - - -def summarize( - writer, - global_step, - scalars={}, - histograms={}, - images={}, - audios={}, - audio_sampling_rate=22050, -): - for k, v in scalars.items(): - writer.add_scalar(k, v, global_step) - for k, v in histograms.items(): - writer.add_histogram(k, v, global_step) - for k, v in images.items(): - writer.add_image(k, v, global_step, dataformats="HWC") - for k, v in audios.items(): - writer.add_audio(k, v, global_step, audio_sampling_rate) - - -def latest_checkpoint_path(dir_path, regex="G_*.pth"): - f_list = glob.glob(os.path.join(dir_path, regex)) - f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) - x = f_list[-1] - logger.debug(x) - return x - - -def plot_spectrogram_to_numpy(spectrogram): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger("matplotlib") - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - plt.xlabel("Frames") - plt.ylabel("Channels") - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return 
data - - -def plot_alignment_to_numpy(alignment, info=None): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger("matplotlib") - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(6, 4)) - im = ax.imshow( - alignment.transpose(), aspect="auto", origin="lower", interpolation="none" - ) - fig.colorbar(im, ax=ax) - xlabel = "Decoder timestep" - if info is not None: - xlabel += "\n\n" + info - plt.xlabel(xlabel) - plt.ylabel("Encoder timestep") - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return data - - -def load_wav_to_torch(full_path): - sampling_rate, data = read(full_path) - return torch.FloatTensor(data.astype(np.float32)), sampling_rate - - -def load_filepaths_and_text(filename, split="|"): - try: - with open(filename, encoding="utf-8") as f: - filepaths_and_text = [line.strip().split(split) for line in f] - except UnicodeDecodeError: - with open(filename) as f: - filepaths_and_text = [line.strip().split(split) for line in f] - - return filepaths_and_text - - -def get_hparams(init=True): - """ - todo: - 结尾七人组: - 保存频率、总epoch done - bs done - pretrainG、pretrainD done - 卡号:os.en["CUDA_VISIBLE_DEVICES"] done - if_latest done - 模型:if_f0 done - 采样率:自动选择config done - 是否缓存数据集进GPU:if_cache_data_in_gpu done - - -m: - 自动决定training_files路径,改掉train_nsf_load_pretrain.py里的hps.data.training_files done - -c不要了 - """ - parser = argparse.ArgumentParser() - parser.add_argument( - "-se", - "--save_every_epoch", - type=int, - required=True, - help="checkpoint save frequency (epoch)", - ) - parser.add_argument( - "-te", "--total_epoch", type=int, required=True, help="total_epoch" - ) - parser.add_argument( - "-pg", "--pretrainG", type=str, default="", 
help="Pretrained Generator path" - ) - parser.add_argument( - "-pd", "--pretrainD", type=str, default="", help="Pretrained Discriminator path" - ) - parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -") - parser.add_argument( - "-bs", "--batch_size", type=int, required=True, help="batch size" - ) - parser.add_argument( - "-e", "--experiment_dir", type=str, required=True, help="experiment dir" - ) # -m - parser.add_argument( - "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k" - ) - parser.add_argument( - "-sw", - "--save_every_weights", - type=str, - default="0", - help="save the extracted model in weights directory when saving checkpoints", - ) - parser.add_argument( - "-v", "--version", type=str, required=True, help="model version" - ) - parser.add_argument( - "-f0", - "--if_f0", - type=int, - required=True, - help="use f0 as one of the inputs of the model, 1 or 0", - ) - parser.add_argument( - "-l", - "--if_latest", - type=int, - required=True, - help="if only save the latest G/D pth file, 1 or 0", - ) - parser.add_argument( - "-c", - "--if_cache_data_in_gpu", - type=int, - required=True, - help="if caching the dataset in GPU memory, 1 or 0", - ) - - args = parser.parse_args() - name = args.experiment_dir - experiment_dir = os.path.join("./logs", args.experiment_dir) - - config_save_path = os.path.join(experiment_dir, "config.json") - with open(config_save_path, "r") as f: - config = json.load(f) - - hparams = HParams(**config) - hparams.model_dir = hparams.experiment_dir = experiment_dir - hparams.save_every_epoch = args.save_every_epoch - hparams.name = name - hparams.total_epoch = args.total_epoch - hparams.pretrainG = args.pretrainG - hparams.pretrainD = args.pretrainD - hparams.version = args.version - hparams.gpus = args.gpus - hparams.train.batch_size = args.batch_size - hparams.sample_rate = args.sample_rate - hparams.if_f0 = args.if_f0 - hparams.if_latest = args.if_latest - hparams.save_every_weights 
= args.save_every_weights - hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu - hparams.data.training_files = "%s/filelist.txt" % experiment_dir - return hparams - - -def get_hparams_from_dir(model_dir): - config_save_path = os.path.join(model_dir, "config.json") - with open(config_save_path, "r") as f: - data = f.read() - config = json.loads(data) - - hparams = HParams(**config) - hparams.model_dir = model_dir - return hparams - - -def get_hparams_from_file(config_path): - with open(config_path, "r") as f: - data = f.read() - config = json.loads(data) - - hparams = HParams(**config) - return hparams - - -def check_git_hash(model_dir): - source_dir = os.path.dirname(os.path.realpath(__file__)) - if not os.path.exists(os.path.join(source_dir, ".git")): - logger.warning( - "{} is not a git repository, therefore hash value comparison will be ignored.".format( - source_dir - ) - ) - return - - cur_hash = subprocess.getoutput("git rev-parse HEAD") - - path = os.path.join(model_dir, "githash") - if os.path.exists(path): - saved_hash = open(path).read() - if saved_hash != cur_hash: - logger.warning( - "git hash values are different. 
{}(saved) != {}(current)".format( - saved_hash[:8], cur_hash[:8] - ) - ) - else: - open(path, "w").write(cur_hash) - - -def get_logger(model_dir, filename="train.log"): - global logger - logger = logging.getLogger(os.path.basename(model_dir)) - logger.setLevel(logging.DEBUG) - - formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") - if not os.path.exists(model_dir): - os.makedirs(model_dir) - h = logging.FileHandler(os.path.join(model_dir, filename)) - h.setLevel(logging.DEBUG) - h.setFormatter(formatter) - logger.addHandler(h) - return logger - - -class HParams: - def __init__(self, **kwargs): - for k, v in kwargs.items(): - if type(v) == dict: - v = HParams(**v) - self[k] = v - - def keys(self): - return self.__dict__.keys() - - def items(self): - return self.__dict__.items() - - def values(self): - return self.__dict__.values() - - def __len__(self): - return len(self.__dict__) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - return setattr(self, key, value) - - def __contains__(self, key): - return key in self.__dict__ - - def __repr__(self): - return self.__dict__.__repr__() +import argparse +import glob +import json +import logging +import os +import subprocess +import sys +import shutil + +import numpy as np +import torch +from scipy.io.wavfile import read + +MATPLOTLIB_FLAG = False + +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) +logger = logging + + +def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + + ################## + def go(model, bkey): + saved_state_dict = checkpoint_dict[bkey] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): # 模型需要的shape + try: + new_state_dict[k] = saved_state_dict[k] + if 
saved_state_dict[k].shape != state_dict[k].shape: + logger.warning( + "shape-%s-mismatch. need: %s, get: %s", + k, + state_dict[k].shape, + saved_state_dict[k].shape, + ) # + raise KeyError + except: + # logger.info(traceback.format_exc()) + logger.info("%s is not in the checkpoint", k) # pretrain缺失的 + new_state_dict[k] = v # 模型自带的随机值 + if hasattr(model, "module"): + model.module.load_state_dict(new_state_dict, strict=False) + else: + model.load_state_dict(new_state_dict, strict=False) + return model + + go(combd, "combd") + model = go(sbd, "sbd") + ############# + logger.info("Loaded model weights") + + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + if ( + optimizer is not None and load_opt == 1 + ): ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch + # try: + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + # except: + # traceback.print_exc() + logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +# def load_checkpoint(checkpoint_path, model, optimizer=None): +# assert os.path.isfile(checkpoint_path) +# checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') +# iteration = checkpoint_dict['iteration'] +# learning_rate = checkpoint_dict['learning_rate'] +# if optimizer is not None: +# optimizer.load_state_dict(checkpoint_dict['optimizer']) +# # print(1111) +# saved_state_dict = checkpoint_dict['model'] +# # print(1111) +# +# if hasattr(model, 'module'): +# state_dict = model.module.state_dict() +# else: +# state_dict = model.state_dict() +# new_state_dict= {} +# for k, v in state_dict.items(): +# try: +# new_state_dict[k] = saved_state_dict[k] +# except: +# logger.info("%s is not in the checkpoint" % k) +# new_state_dict[k] = v +# if hasattr(model, 'module'): +# model.module.load_state_dict(new_state_dict) +# else: +# model.load_state_dict(new_state_dict) +# logger.info("Loaded checkpoint '{}' (epoch {})" 
.format( +# checkpoint_path, iteration)) +# return model, optimizer, learning_rate, iteration +def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") + + saved_state_dict = checkpoint_dict["model"] + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict = {} + for k, v in state_dict.items(): # 模型需要的shape + try: + new_state_dict[k] = saved_state_dict[k] + if saved_state_dict[k].shape != state_dict[k].shape: + logger.warning( + "shape-%s-mismatch|need-%s|get-%s", + k, + state_dict[k].shape, + saved_state_dict[k].shape, + ) # + raise KeyError + except: + # logger.info(traceback.format_exc()) + logger.info("%s is not in the checkpoint", k) # pretrain缺失的 + new_state_dict[k] = v # 模型自带的随机值 + if hasattr(model, "module"): + model.module.load_state_dict(new_state_dict, strict=False) + else: + model.load_state_dict(new_state_dict, strict=False) + logger.info("Loaded model weights") + + iteration = checkpoint_dict["iteration"] + learning_rate = checkpoint_dict["learning_rate"] + if ( + optimizer is not None and load_opt == 1 + ): ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch + # try: + optimizer.load_state_dict(checkpoint_dict["optimizer"]) + # except: + # traceback.print_exc() + logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): + logger.info( + "Saving model and optimizer state at epoch {} to {}".format( + iteration, checkpoint_path + ) + ) + if hasattr(model, "module"): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save( + { + "model": state_dict, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + 
checkpoint_path, + ) + + +def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path): + logger.info( + "Saving model and optimizer state at epoch {} to {}".format( + iteration, checkpoint_path + ) + ) + if hasattr(combd, "module"): + state_dict_combd = combd.module.state_dict() + else: + state_dict_combd = combd.state_dict() + if hasattr(sbd, "module"): + state_dict_sbd = sbd.module.state_dict() + else: + state_dict_sbd = sbd.state_dict() + torch.save( + { + "combd": state_dict_combd, + "sbd": state_dict_sbd, + "iteration": iteration, + "optimizer": optimizer.state_dict(), + "learning_rate": learning_rate, + }, + checkpoint_path, + ) + + +def summarize( + writer, + global_step, + scalars={}, + histograms={}, + images={}, + audios={}, + audio_sampling_rate=22050, +): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats="HWC") + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + + +def latest_checkpoint_path(dir_path, regex="G_*.pth"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + logger.debug(x) + return x + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + data = 
data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow( + alignment.transpose(), aspect="auto", origin="lower", interpolation="none" + ) + fig.colorbar(im, ax=ax) + xlabel = "Decoder timestep" + if info is not None: + xlabel += "\n\n" + info + plt.xlabel(xlabel) + plt.ylabel("Encoder timestep") + plt.tight_layout() + + fig.canvas.draw() + data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] + plt.close() + return data + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate + + +def load_filepaths_and_text(filename, split="|"): + with open(filename, encoding="utf-8") as f: + filepaths_and_text = [line.strip().split(split) for line in f] + return filepaths_and_text + + +def get_hparams(init=True): + """ + todo: + 结尾七人组: + 保存频率、总epoch done + bs done + pretrainG、pretrainD done + 卡号:os.en["CUDA_VISIBLE_DEVICES"] done + if_latest done + 模型:if_f0 done + 采样率:自动选择config done + 是否缓存数据集进GPU:if_cache_data_in_gpu done + + -m: + 自动决定training_files路径,改掉train_nsf_load_pretrain.py里的hps.data.training_files done + -c不要了 + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "-se", + "--save_every_epoch", + type=int, + required=True, + help="checkpoint save frequency (epoch)", + ) + parser.add_argument( + "-te", "--total_epoch", type=int, required=True, help="total_epoch" + ) + parser.add_argument( + "-pg", "--pretrainG", type=str, default="", help="Pretrained Generator path" + ) + 
parser.add_argument( + "-pd", "--pretrainD", type=str, default="", help="Pretrained Discriminator path" + ) + parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -") + parser.add_argument( + "-bs", "--batch_size", type=int, required=True, help="batch size" + ) + parser.add_argument( + "-e", "--experiment_dir", type=str, required=True, help="experiment dir" + ) # -m + parser.add_argument( + "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k" + ) + parser.add_argument( + "-sw", + "--save_every_weights", + type=str, + default="0", + help="save the extracted model in weights directory when saving checkpoints", + ) + parser.add_argument( + "-v", "--version", type=str, required=True, help="model version" + ) + parser.add_argument( + "-f0", + "--if_f0", + type=int, + required=True, + help="use f0 as one of the inputs of the model, 1 or 0", + ) + parser.add_argument( + "-l", + "--if_latest", + type=int, + required=True, + help="if only save the latest G/D pth file, 1 or 0", + ) + parser.add_argument( + "-c", + "--if_cache_data_in_gpu", + type=int, + required=True, + help="if caching the dataset in GPU memory, 1 or 0", + ) + + args = parser.parse_args() + name = args.experiment_dir + experiment_dir = os.path.join("./logs", args.experiment_dir) + + config_save_path = os.path.join(experiment_dir, "config.json") + with open(config_save_path, "r") as f: + config = json.load(f) + + hparams = HParams(**config) + hparams.model_dir = hparams.experiment_dir = experiment_dir + hparams.save_every_epoch = args.save_every_epoch + hparams.name = name + hparams.total_epoch = args.total_epoch + hparams.pretrainG = args.pretrainG + hparams.pretrainD = args.pretrainD + hparams.version = args.version + hparams.gpus = args.gpus + hparams.train.batch_size = args.batch_size + hparams.sample_rate = args.sample_rate + hparams.if_f0 = args.if_f0 + hparams.if_latest = args.if_latest + hparams.save_every_weights = args.save_every_weights + 
hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu + hparams.data.training_files = "%s/filelist.txt" % experiment_dir + return hparams + + +def get_hparams_from_dir(model_dir): + config_save_path = os.path.join(model_dir, "config.json") + with open(config_save_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + hparams.model_dir = model_dir + return hparams + + +def get_hparams_from_file(config_path): + with open(config_path, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + return hparams + + +def check_git_hash(model_dir): + source_dir = os.path.dirname(os.path.realpath(__file__)) + if not os.path.exists(os.path.join(source_dir, ".git")): + logger.warning( + "{} is not a git repository, therefore hash value comparison will be ignored.".format( + source_dir + ) + ) + return + + cur_hash = subprocess.getoutput("git rev-parse HEAD") + + path = os.path.join(model_dir, "githash") + if os.path.exists(path): + saved_hash = open(path).read() + if saved_hash != cur_hash: + logger.warning( + "git hash values are different. 
{}(saved) != {}(current)".format( + saved_hash[:8], cur_hash[:8] + ) + ) + else: + open(path, "w").write(cur_hash) + + +def get_logger(model_dir, filename="train.log"): + global logger + logger = logging.getLogger(os.path.basename(model_dir)) + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + h = logging.FileHandler(os.path.join(model_dir, filename)) + h.setLevel(logging.DEBUG) + h.setFormatter(formatter) + logger.addHandler(h) + return logger + + +class HParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() diff --git a/infer/modules/train/train.py b/infer/modules/train/train.py index 38a5678282..e7489dd0f2 100644 --- a/infer/modules/train/train.py +++ b/infer/modules/train/train.py @@ -18,6 +18,14 @@ import torch +# PyTorch 2.6+ defaults weights_only=True which breaks legacy model loading +_original_torch_load = torch.load +def _patched_torch_load(*args, **kwargs): + if "weights_only" not in kwargs: + kwargs["weights_only"] = False + return _original_torch_load(*args, **kwargs) +torch.load = _patched_torch_load + try: import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import @@ -44,6 +52,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm from 
infer.lib.infer_pack import commons from infer.lib.train.data_utils import ( @@ -103,12 +112,15 @@ def main(): n_gpus = 1 os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(randint(20000, 55555)) + if n_gpus == 1: + # Skip subprocess spawn for single-GPU — avoids silent Windows mp crashes + run(0, 1, hps) + return children = [] - logger = utils.get_logger(hps.model_dir) for i in range(n_gpus): subproc = mp.Process( target=run, - args=(i, n_gpus, hps, logger), + args=(i, n_gpus, hps), ) children.append(subproc) subproc.start() @@ -117,18 +129,23 @@ def main(): children[i].join() -def run(rank, n_gpus, hps, logger: logging.Logger): +def run( + rank, + n_gpus, + hps, +): global global_step if rank == 0: - # logger = utils.get_logger(hps.model_dir) + logger = utils.get_logger(hps.model_dir) logger.info(hps) # utils.check_git_hash(hps.model_dir) writer = SummaryWriter(log_dir=hps.model_dir) writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) - dist.init_process_group( - backend="gloo", init_method="env://", world_size=n_gpus, rank=rank - ) + if n_gpus > 1: + dist.init_process_group( + backend="gloo", init_method="env://", world_size=n_gpus, rank=rank + ) torch.manual_seed(hps.train.seed) if torch.cuda.is_available(): torch.cuda.set_device(rank) @@ -154,13 +171,11 @@ def run(rank, n_gpus, hps, logger: logging.Logger): collate_fn = TextAudioCollate() train_loader = DataLoader( train_dataset, - num_workers=4, + num_workers=0, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler, - persistent_workers=True, - prefetch_factor=8, ) if hps.if_f0 == 1: net_g = RVC_Model_f0( @@ -196,14 +211,15 @@ def run(rank, n_gpus, hps, logger: logging.Logger): ) # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True) # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True) - if hasattr(torch, "xpu") and torch.xpu.is_available(): - pass - elif torch.cuda.is_available(): - net_g = DDP(net_g, 
device_ids=[rank]) - net_d = DDP(net_d, device_ids=[rank]) - else: - net_g = DDP(net_g) - net_d = DDP(net_d) + if n_gpus > 1: + if hasattr(torch, "xpu") and torch.xpu.is_available(): + pass + elif torch.cuda.is_available(): + net_g = DDP(net_g, device_ids=[rank]) + net_d = DDP(net_d, device_ids=[rank]) + else: + net_g = DDP(net_g) + net_d = DDP(net_d) try: # 如果能加载自动resume _, _, _, epoch_str = utils.load_checkpoint( @@ -263,35 +279,46 @@ def run(rank, n_gpus, hps, logger: logging.Logger): scaler = GradScaler(enabled=hps.train.fp16_run) cache = [] - for epoch in range(epoch_str, hps.train.epochs + 1): - if rank == 0: - train_and_evaluate( - rank, - epoch, - hps, - [net_g, net_d], - [optim_g, optim_d], - [scheduler_g, scheduler_d], - scaler, - [train_loader, None], - logger, - [writer, writer_eval], - cache, - ) - else: - train_and_evaluate( - rank, - epoch, - hps, - [net_g, net_d], - [optim_g, optim_d], - [scheduler_g, scheduler_d], - scaler, - [train_loader, None], - None, - None, - cache, - ) + epoch_bar = tqdm( + range(epoch_str, hps.total_epoch + 1), + desc="Training", + unit="epoch", + disable=rank != 0, + ) + for epoch in epoch_bar: + try: + if rank == 0: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + logger, + [writer, writer_eval], + cache, + ) + else: + train_and_evaluate( + rank, + epoch, + hps, + [net_g, net_d], + [optim_g, optim_d], + [scheduler_g, scheduler_d], + scaler, + [train_loader, None], + None, + None, + cache, + ) + except Exception as e: + if rank == 0: + logger.error("Training error at epoch %d: %s", epoch, e, exc_info=True) + raise scheduler_g.step() scheduler_d.step() @@ -395,7 +422,15 @@ def train_and_evaluate( # Run steps epoch_recorder = EpochRecorder() - for batch_idx, info in data_iterator: + if rank == 0: + if hps.if_cache_data_in_gpu == True: + total_batches = len(cache) if cache else len(train_loader) + else: + total_batches = 
len(train_loader) + batch_bar = tqdm(data_iterator, total=total_batches, desc=f"Epoch {epoch}", leave=False) + else: + batch_bar = data_iterator + for batch_idx, info in batch_bar: # Data ## Unpack if hps.if_f0 == 1: @@ -500,6 +535,12 @@ def train_and_evaluate( scaler.update() if rank == 0: + batch_bar.set_postfix( + disc=f"{loss_disc:.3f}", + gen=f"{loss_gen:.3f}", + mel=f"{loss_mel:.3f}", + kl=f"{loss_kl:.3f}", + ) if global_step % hps.train.log_interval == 0: lr = optim_g.param_groups[0]["lr"] logger.info( diff --git a/infer/modules/vc/utils.py b/infer/modules/vc/utils.py index c128707cfd..db5e9b155d 100644 --- a/infer/modules/vc/utils.py +++ b/infer/modules/vc/utils.py @@ -1,5 +1,15 @@ import os +import torch + +# PyTorch 2.6+ defaults weights_only=True which breaks fairseq's checkpoint loading +_orig_torch_load = torch.load +def _patched_torch_load(*args, **kwargs): + if "weights_only" not in kwargs: + kwargs["weights_only"] = False + return _orig_torch_load(*args, **kwargs) +torch.load = _patched_torch_load + from fairseq import checkpoint_utils diff --git a/requirements-amd.txt b/requirements-amd.txt index ee8fa374c9..e16fd0d688 100644 --- a/requirements-amd.txt +++ b/requirements-amd.txt @@ -5,7 +5,7 @@ numpy==1.23.5 scipy librosa==0.10.2 llvmlite==0.39.0 -fairseq==0.12.2 +fairseq @ git+https://github.com/One-sixth/fairseq.git faiss-cpu==1.7.3 gradio==3.34.0 Cython diff --git a/requirements-dml.txt b/requirements-dml.txt index 6987607703..bcaa109697 100644 --- a/requirements-dml.txt +++ b/requirements-dml.txt @@ -4,7 +4,7 @@ numpy==1.23.5 scipy librosa==0.10.2 llvmlite==0.39.0 -fairseq==0.12.2 +fairseq @ git+https://github.com/One-sixth/fairseq.git faiss-cpu==1.7.3 gradio==3.34.0 Cython diff --git a/requirements-ipex.txt b/requirements-ipex.txt index b59bdcbadf..1e1deb6ff0 100644 --- a/requirements-ipex.txt +++ b/requirements-ipex.txt @@ -9,7 +9,7 @@ numpy==1.23.5 scipy librosa==0.10.2 llvmlite==0.39.0 -fairseq==0.12.2 +fairseq @ 
git+https://github.com/One-sixth/fairseq.git faiss-cpu==1.7.3 gradio==3.34.0 Cython diff --git a/requirements.txt b/requirements.txt index 28635b8a93..41273db2b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ numpy==1.23.5 scipy librosa==0.9.1 llvmlite==0.39.0 -fairseq==0.12.2 +fairseq @ git+https://github.com/One-sixth/fairseq.git faiss-cpu==1.7.3 gradio==3.34.0 Cython