# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Function to get random segments."""

import collections
import logging
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from phonemizer import phonemize
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from unidecode import unidecode

from symbols import symbol_table


def get_random_segments(
x: torch.Tensor,
x_lengths: torch.Tensor,
segment_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get random segments.
Args:
x (Tensor): Input tensor (B, C, T).
x_lengths (Tensor): Length tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
Tensor: Start index tensor (B,).
"""
b, c, t = x.size()
max_start_idx = x_lengths - segment_size
max_start_idx[max_start_idx < 0] = 0
start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
dtype=torch.long,
)
segments = get_segments(x, start_idxs, segment_size)
return segments, start_idxs


def get_segments(
x: torch.Tensor,
start_idxs: torch.Tensor,
segment_size: int,
) -> torch.Tensor:
"""Get segments.
Args:
x (Tensor): Input tensor (B, C, T).
start_idxs (Tensor): Start index tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
"""
b, c, t = x.size()
segments = x.new_zeros(b, c, segment_size)
for i, start_idx in enumerate(start_idxs):
segments[i] = x[i, :, start_idx : start_idx + segment_size]
return segments
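
# Illustrative usage of the two helpers above (not in the original file):
# with a batch of 3 feature sequences of shape (B, C, T), pick a random
# 32-frame window from each utterance.
#
#   feats = torch.randn(3, 80, 200)          # (B, C, T)
#   lengths = torch.tensor([200, 150, 120])  # valid frames per utterance
#   segs, starts = get_random_segments(feats, lengths, segment_size=32)
#   # segs:   (3, 80, 32) random windows
#   # starts: (3,) start frame of each window
#   same = get_segments(feats, starts, segment_size=32)  # deterministic variant
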
# https://github.com/espnet/espnet/blob/master/espnet2/torch_utils/device_funcs.py
def force_gatherable(data, device):
    """Recursively convert an object so that torch.nn.DataParallel can gather it.

    Unlike to_device(), plain float and int values are converted to torch.Tensor.

    DataParallel can only gather objects that are
      - a torch.cuda.Tensor with at least one dimension (a 0-dim tensor
        triggers a warning),
      - or a (possibly nested) list, tuple, or dict of such tensors.
    """
if isinstance(data, dict):
return {k: force_gatherable(v, device) for k, v in data.items()}
# DataParallel can't handle NamedTuple well
elif isinstance(data, tuple) and type(data) is not tuple:
return type(data)(*[force_gatherable(o, device) for o in data])
elif isinstance(data, (list, tuple, set)):
return type(data)(force_gatherable(v, device) for v in data)
elif isinstance(data, np.ndarray):
return force_gatherable(torch.from_numpy(data), device)
elif isinstance(data, torch.Tensor):
if data.dim() == 0:
# To 1-dim array
data = data[None]
return data.to(device)
elif isinstance(data, float):
return torch.tensor([data], dtype=torch.float, device=device)
elif isinstance(data, int):
return torch.tensor([data], dtype=torch.long, device=device)
elif data is None:
return None
else:
warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
return data
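
# Illustrative usage (not in the original file): make a stats dict gatherable
# across DataParallel / DDP replicas; plain Python numbers and 0-dim tensors
# become 1-dim tensors on the target device.
#
#   stats = {"loss": 1.5, "samples": 8, "mel_err": torch.tensor(0.3)}
#   stats = force_gatherable(stats, device=torch.device("cpu"))
#   # every value is now a tensor of shape (1,) on the given device
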
# The following codes are based on https://github.com/jaywalnut310/vits
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, ' ', text)


def convert_to_ascii(text):
    return unidecode(text)
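
# Illustrative examples (not in the original file) of the helpers above:
#
#   expand_abbreviations("Dr. Smith met Mr. Jones.")  # -> "doctor Smith met mister Jones."
#   collapse_whitespace("hello   world")              # -> "hello world"
#   convert_to_ascii("café au lait")                  # -> "cafe au lait"
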
def text_clean(text):
    '''Pipeline for English text: convert to ASCII, lowercase, expand
    abbreviations, and phonemize (keeping punctuation and stress marks).

    Returns:
      A string of phonemes.
    '''
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
phonemes = phonemize(
text,
language='en-us',
backend='espeak',
strip=True,
preserve_punctuation=True,
with_stress=True,
)
phonemes = collapse_whitespace(phonemes)
return phonemes
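
# Illustrative usage (not in the original file); the exact output depends on
# the installed espeak backend:
#
#   text_clean("Dr. Strange, hello!")
#   # -> an IPA phoneme string with stress marks, punctuation preserved
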
# Mappings from symbol to numeric ID and vice versa:
symbol_to_id = {s: i for i, s in enumerate(symbol_table)}
id_to_symbol = {i: s for i, s in enumerate(symbol_table)}
# def text_to_sequence(text: str) -> List[int]:
# '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
# '''
# cleaned_text = text_clean(text)
# sequence = [symbol_to_id[symbol] for symbol in cleaned_text]
# return sequence
#
#
# def sequence_to_text(sequence: List[int]) -> str:
# '''Converts a sequence of IDs back to a string'''
# result = ''.join(id_to_symbol[symbol_id] for symbol_id in sequence)
# return result


def intersperse(sequence, item=0):
    """Return a new list with `item` inserted between every two elements of
    `sequence` and at both ends, e.g. to put blank tokens between symbol ids."""
    result = [item] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result
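
# Illustrative examples (not in the original file):
#
#   intersperse([5, 3, 9])          # -> [0, 5, 0, 3, 0, 9, 0]
#   intersperse([5, 3, 9], item=7)  # -> [7, 5, 7, 3, 7, 9, 7]
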
def prepare_token_batch(
    texts: List[str],
    phonemes: Optional[List[str]] = None,
    intersperse_blank: bool = True,
    blank_id: int = 0,
    pad_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, bool]:
    """Convert a list of text strings into a padded batch of symbol tokens.

    Args:
      texts: list of text strings.
      phonemes: optional list of pre-computed phoneme strings; if None,
        the texts are normalized and phonemized here.
      intersperse_blank: whether to intersperse blank tokens in the converted token sequence.
      blank_id: index of the blank token.
      pad_id: padding index.

    Returns:
      A tuple (sequences, lengths, skip) where `sequences` is a (B, T) int64
      tensor of padded token ids, `lengths` is a (B,) int64 tensor of the
      unpadded lengths, and `skip` is True if any utterance failed to convert
      and the batch should be discarded.
    """
if phonemes is None:
# normalize text
normalized_texts = []
for text in texts:
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_abbreviations(text)
normalized_texts.append(text)
# convert to phonemes
phonemes = phonemize(
normalized_texts,
language='en-us',
backend='espeak',
strip=True,
preserve_punctuation=True,
with_stress=True,
)
phonemes = [collapse_whitespace(sequence) for sequence in phonemes]
# convert to symbol ids
lengths = []
sequences = []
skip = False
for idx, sequence in enumerate(phonemes):
        try:
            sequence = [symbol_to_id[symbol] for symbol in sequence]
        except Exception:
            logging.warning(f"Found symbols not in the symbol table: {phonemes[idx]}")
            skip = True
        if intersperse_blank:
            sequence = intersperse(sequence, blank_id)
        try:
            sequences.append(torch.tensor(sequence, dtype=torch.int64))
        except Exception:
            logging.warning(f"Failed to convert sequence to a tensor: {sequence}")
            skip = True
lengths.append(len(sequence))
sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
lengths = torch.tensor(lengths, dtype=torch.int64)
return sequences, lengths, skip
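
# Illustrative usage (not in the original file); the actual token ids depend on
# `symbol_table` and the espeak backend:
#
#   tokens, token_lengths, skip = prepare_token_batch(
#       ["How are you doing?", "Fine, thank you."]
#   )
#   # tokens:        (2, T_max) int64, padded with pad_id
#   # token_lengths: (2,) int64 lengths before padding
#   # skip:          True if any utterance contained an unknown symbol
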
class MetricsTracker(collections.defaultdict):
def __init__(self):
# Passing the type 'int' to the base-class constructor
# makes undefined items default to int() which is zero.
# This class will play a role as metrics tracker.
# It can record many metrics, including but not limited to loss.
super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v
for k, v in other.items():
ans[k] = ans[k] + v
return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
ans = MetricsTracker()
for k, v in self.items():
ans[k] = v * alpha
return ans

    def __str__(self) -> str:
ans = ""
for k, v in self.norm_items():
norm_value = "%.4g" % v
ans += str(k) + "=" + str(norm_value) + ", "
samples = "%.2f" % self["samples"]
ans += "over " + str(samples) + " samples."
return ans

    def norm_items(self) -> List[Tuple[str, float]]:
"""
Returns a list of pairs, like:
[('loss_1', 0.1), ('loss_2', 0.07)]
"""
samples = self["samples"] if "samples" in self else 1
ans = []
for k, v in self.items():
if k == "samples":
continue
norm_value = float(v) / samples
ans.append((k, norm_value))
return ans

    def reduce(self, device):
        """Sum the metrics over all processes with torch.distributed so that
        every process ends up with the same totals."""
keys = sorted(self.keys())
s = torch.tensor([float(self[k]) for k in keys], device=device)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
for k, v in zip(keys, s.cpu().tolist()):
self[k] = v

    def write_summary(
self,
tb_writer: SummaryWriter,
prefix: str,
batch_idx: int,
) -> None:
"""Add logging information to a TensorBoard writer.
Args:
tb_writer: a TensorBoard writer
prefix: a prefix for the name of the loss, e.g. "train/valid_",
or "train/current_"
batch_idx: The current batch index, used as the x-axis of the plot.
"""
for k, v in self.norm_items():
tb_writer.add_scalar(prefix + k, v, batch_idx)
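
# Illustrative usage of MetricsTracker (not in the original file):
#
#   tracker = MetricsTracker()
#   tracker["samples"] += 8
#   tracker["loss_disc"] += 1.6
#   tracker["loss_gen"] += 2.4
#   str(tracker)          # -> "loss_disc=0.2, loss_gen=0.3, over 8.00 samples."
#   tracker.norm_items()  # -> [("loss_disc", 0.2), ("loss_gen", 0.3)]
#   # Trackers from different batches can be combined with `+`, and
#   # `reduce(device)` sums them across DDP processes.
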
# checkpoint saving and loading
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler


def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRSchedulerType] = None,
scheduler_d: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
) -> None:
"""Save training information to a file.
Args:
filename:
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
params:
User defined parameters, e.g., epoch, loss.
optimizer_g:
The optimizer for generator used in the training.
Its `state_dict` will be saved.
optimizer_d:
The optimizer for discriminator used in the training.
Its `state_dict` will be saved.
scheduler_g:
The learning rate scheduler for generator used in the training.
Its `state_dict` will be saved.
scheduler_d:
The learning rate scheduler for discriminator used in the training.
Its `state_dict` will be saved.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      sampler:
        The sampler used in the training dataset. We only save its
        `state_dict()`.
      rank:
        Used in DDP. We save checkpoint only for the node whose rank is 0.
Returns:
Return None.
"""
if rank != 0:
return
logging.info(f"Saving checkpoint to {filename}")
if isinstance(model, DDP):
model = model.module
checkpoint = {
"model": model.state_dict(),
"optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
"optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
"scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
"scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
"sampler": sampler.state_dict() if sampler is not None else None,
}
if params:
for k, v in params.items():
assert k not in checkpoint
checkpoint[k] = v
torch.save(checkpoint, filename)


def save_checkpoint_with_global_batch_idx(
out_dir: Path,
global_batch_idx: int,
model: Union[nn.Module, DDP],
params: Optional[Dict[str, Any]] = None,
optimizer_g: Optional[Optimizer] = None,
optimizer_d: Optional[Optimizer] = None,
scheduler_g: Optional[LRSchedulerType] = None,
scheduler_d: Optional[LRSchedulerType] = None,
scaler: Optional[GradScaler] = None,
sampler: Optional[CutSampler] = None,
rank: int = 0,
):
"""Save training info after processing given number of batches.
Args:
out_dir:
The directory to save the checkpoint.
global_batch_idx:
The number of batches processed so far from the very start of the
training. The saved checkpoint will have the following filename:
f'out_dir / checkpoint-{global_batch_idx}.pt'
model:
The neural network model whose `state_dict` will be saved in the
checkpoint.
params:
A dict of training configurations to be saved.
optimizer_g:
The optimizer for generator used in the training.
Its `state_dict` will be saved.
optimizer_d:
The optimizer for discriminator used in the training.
Its `state_dict` will be saved.
scheduler_g:
The learning rate scheduler for generator used in the training.
Its `state_dict` will be saved.
scheduler_d:
The learning rate scheduler for discriminator used in the training.
Its `state_dict` will be saved.
scaler:
        The scaler used for mixed precision training. Its `state_dict` will
be saved.
sampler:
The sampler used in the training dataset.
rank:
The rank ID used in DDP training of the current node. Set it to 0
if DDP is not used.
"""
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
save_checkpoint(
filename=filename,
model=model,
params=params,
optimizer_g=optimizer_g,
optimizer_d=optimizer_d,
scheduler_g=scheduler_g,
scheduler_d=scheduler_d,
scaler=scaler,
sampler=sampler,
rank=rank,
)
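
# Illustrative usage (not in the original file); variable names here are
# hypothetical placeholders from a typical training loop:
#
#   save_checkpoint_with_global_batch_idx(
#       out_dir=Path("exp/vits"),
#       global_batch_idx=global_batch_idx,
#       model=model,
#       params={"epoch": epoch, "best_loss": best_loss},
#       optimizer_g=optimizer_g,
#       optimizer_d=optimizer_d,
#       scheduler_g=scheduler_g,
#       scheduler_d=scheduler_d,
#       scaler=scaler,
#       sampler=train_dl.sampler,
#       rank=rank,
#   )
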
# def plot_feature(feature):
# """
# Display the feature matrix as an image. Requires matplotlib to be installed.
# """
# import matplotlib.pyplot as plt
#
# feature = np.flip(feature.transpose(1, 0), 0)
# return plt.matshow(feature)


MATPLOTLIB_FLAG = False


def plot_feature(spectrogram):
    """Render a (num_channels, num_frames) feature matrix, e.g. a
    mel-spectrogram, as an (H, W, 3) uint8 image array."""
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring is deprecated; use frombuffer and copy to keep the array writable.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).copy()
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data
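
# Illustrative usage (not in the original file): log the rendered feature to
# TensorBoard; `plot_feature` returns an (H, W, 3) uint8 array.
#
#   img = plot_feature(mel.cpu().numpy())   # mel: (num_mels, num_frames)
#   tb_writer.add_image("train/mel", img, batch_idx, dataformats="HWC")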