# https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/utils/get_random_segments.py
# Copyright 2021 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Utilities for VITS training: random segment extraction, text
normalization and tokenization, metrics tracking, checkpointing, and
feature plotting."""

import collections
import logging
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.sampling.base import CutSampler
from phonemizer import phonemize
from symbols import symbol_table
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter
from unidecode import unidecode

def get_random_segments(
    x: torch.Tensor,
    x_lengths: torch.Tensor,
    segment_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get random segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        x_lengths (Tensor): Length tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).
        Tensor: Start index tensor (B,).
    """
    b, c, t = x.size()
    max_start_idx = x_lengths - segment_size
    max_start_idx[max_start_idx < 0] = 0
    start_idxs = (torch.rand([b]).to(x.device) * max_start_idx).to(
        dtype=torch.long,
    )
    segments = get_segments(x, start_idxs, segment_size)
    return segments, start_idxs

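# Illustrative usage (not part of the original file): sample fixed-size
# training windows from a padded batch of (B, C, T) features; the actual
# slicing is delegated to get_segments() below.
#
#     x = torch.randn(4, 80, 100)
#     x_lengths = torch.tensor([100, 90, 80, 70])
#     segments, start_idxs = get_random_segments(x, x_lengths, segment_size=32)
#     # segments.shape == (4, 80, 32); start_idxs.shape == (4,)
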
def get_segments(
    x: torch.Tensor,
    start_idxs: torch.Tensor,
    segment_size: int,
) -> torch.Tensor:
    """Get segments.

    Args:
        x (Tensor): Input tensor (B, C, T).
        start_idxs (Tensor): Start index tensor (B,).
        segment_size (int): Segment size.

    Returns:
        Tensor: Segmented tensor (B, C, segment_size).
    """
    b, c, t = x.size()
    segments = x.new_zeros(b, c, segment_size)
    for i, start_idx in enumerate(start_idxs):
        segments[i] = x[i, :, start_idx : start_idx + segment_size]
    return segments

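# Illustrative usage (not part of the original file): deterministic
# extraction at known offsets, e.g. to re-slice ground-truth features at
# the offsets chosen by get_random_segments.
#
#     start_idxs = torch.tensor([0, 5, 10, 2])
#     segs = get_segments(torch.randn(4, 80, 100), start_idxs, segment_size=32)
#     # segs.shape == (4, 80, 32)
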
# https://github.com/espnet/espnet/blob/master/espnet2/torch_utils/device_funcs.py
def force_gatherable(data, device):
    """Recursively convert an object so torch.nn.DataParallel can gather it.

    The difference from to_device() is that float and int values are
    converted to torch.Tensor.

    DataParallel requires each returned value to be
      - a torch.cuda.Tensor with 1 or more dimensions (a 0-dimensional
        tensor triggers a warning),
      - or a list, tuple, or dict of such values.
    """
    if isinstance(data, dict):
        return {k: force_gatherable(v, device) for k, v in data.items()}
    # DataParallel can't handle NamedTuple well
    elif isinstance(data, tuple) and type(data) is not tuple:
        return type(data)(*[force_gatherable(o, device) for o in data])
    elif isinstance(data, (list, tuple, set)):
        return type(data)(force_gatherable(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        return force_gatherable(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        if data.dim() == 0:
            # To 1-dim array
            data = data[None]
        return data.to(device)
    elif isinstance(data, float):
        return torch.tensor([data], dtype=torch.float, device=device)
    elif isinstance(data, int):
        return torch.tensor([data], dtype=torch.long, device=device)
    elif data is None:
        return None
    else:
        warnings.warn(f"{type(data)} may not be gatherable by DataParallel")
        return data

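# Illustrative example (not part of the original file): mixed Python
# scalars and numpy arrays become 1-D tensors on the target device.
#
#     stats = {"loss": 0.5, "count": 3, "mask": np.zeros((2, 4))}
#     gathered = force_gatherable(stats, device="cpu")
#     # gathered["loss"] -> tensor([0.5000]); gathered["count"] -> tensor([3])
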
# The following code is based on https://github.com/jaywalnut310/vits

# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile(r'\b%s\.' % x[0], re.IGNORECASE), x[1]) for x in [
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
]]

def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text

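# Illustrative example (not part of the original file): each "abbr." form
# is rewritten to its expansion, in list order.
#
#     expand_abbreviations("dr. smith met mrs. jones")
#     # -> "doctor smith met misess jones"
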
def lowercase(text):
    return text.lower()

def collapse_whitespace(text):
    return re.sub(_whitespace_re, ' ', text)

def convert_to_ascii(text):
    return unidecode(text)

def text_clean(text):
    '''Pipeline for English text: ASCII conversion, lowercasing, and
    abbreviation expansion, followed by phonemization that preserves
    punctuation and stress marks.

    Returns:
      A string of phonemes.
    '''
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_abbreviations(text)
    phonemes = phonemize(
        text,
        language='en-us',
        backend='espeak',
        strip=True,
        preserve_punctuation=True,
        with_stress=True,
    )
    phonemes = collapse_whitespace(phonemes)
    return phonemes

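# Illustrative usage (not part of the original file; requires the espeak
# backend of phonemizer):
#
#     text_clean("Hello, Dr. Smith!")
#     # -> an IPA phoneme string with punctuation and stress kept,
#     #    e.g. roughly "həlˈoʊ, dˈɑːktɚ smˈɪθ!"
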
# Mappings from symbol to numeric ID and vice versa:
symbol_to_id = {s: i for i, s in enumerate(symbol_table)}
id_to_symbol = {i: s for i, s in enumerate(symbol_table)}

# def text_to_sequence(text: str) -> List[int]:
#     '''Converts a string of text to a sequence of IDs corresponding to
#     the symbols in the text.
#     '''
#     cleaned_text = text_clean(text)
#     sequence = [symbol_to_id[symbol] for symbol in cleaned_text]
#     return sequence
#
#
# def sequence_to_text(sequence: List[int]) -> str:
#     '''Converts a sequence of IDs back to a string.'''
#     result = ''.join(id_to_symbol[symbol_id] for symbol_id in sequence)
#     return result

def intersperse(sequence, item=0):
    """Insert `item` between consecutive elements and at both ends,
    e.g. [1, 2] -> [0, 1, 0, 2, 0]."""
    result = [item] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result

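# Example (not part of the original file): VITS-style models typically
# intersperse a blank token to give the aligner room between phonemes.
#
#     intersperse([7, 8, 9])       # -> [0, 7, 0, 8, 0, 9, 0]
#     intersperse([7, 8], item=5)  # -> [5, 7, 5, 8, 5]
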
def prepare_token_batch(
    texts: List[str],
    phonemes: Optional[List[str]] = None,
    intersperse_blank: bool = True,
    blank_id: int = 0,
    pad_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, bool]:
    """Convert a list of text strings into a batch of symbol tokens with padding.

    Args:
      texts: list of text strings.
      phonemes: optional list of pre-computed phoneme strings; if None,
        the texts are normalized and phonemized here.
      intersperse_blank: whether to intersperse blank tokens in the
        converted token sequence.
      blank_id: index of the blank token.
      pad_id: padding index.

    Returns:
      A tuple of (padded token tensor (B, L), length tensor (B,), skip flag).
      The skip flag is True if any utterance contained a symbol that is not
      in the symbol table, so the caller can discard the batch.
    """
    if phonemes is None:
        # normalize text
        normalized_texts = []
        for text in texts:
            text = convert_to_ascii(text)
            text = lowercase(text)
            text = expand_abbreviations(text)
            normalized_texts.append(text)

        # convert to phonemes
        phonemes = phonemize(
            normalized_texts,
            language='en-us',
            backend='espeak',
            strip=True,
            preserve_punctuation=True,
            with_stress=True,
        )
        phonemes = [collapse_whitespace(sequence) for sequence in phonemes]

    # convert to symbol ids
    lengths = []
    sequences = []
    skip = False
    for idx, sequence in enumerate(phonemes):
        try:
            token_ids = [symbol_to_id[symbol] for symbol in sequence]
        except KeyError:
            # The phonemizer produced a symbol that is not in the symbol
            # table; flag the batch so the caller can skip it.
            logging.warning(f"Out-of-vocabulary symbol in: {phonemes[idx]}")
            skip = True
            continue
        if intersperse_blank:
            token_ids = intersperse(token_ids, blank_id)
        sequences.append(torch.tensor(token_ids, dtype=torch.int64))
        lengths.append(len(token_ids))

    sequences = pad_sequence(sequences, batch_first=True, padding_value=pad_id)
    lengths = torch.tensor(lengths, dtype=torch.int64)
    return sequences, lengths, skip

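# Illustrative usage (not part of the original file; requires phonemizer
# with espeak and the local symbols.py):
#
#     tokens, token_lengths, skip = prepare_token_batch(
#         ["Hello world.", "Good morning."]
#     )
#     if not skip:
#         ...  # feed tokens (B, L) and token_lengths (B,) to the model
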
class MetricsTracker(collections.defaultdict):
    def __init__(self):
        # Passing the type 'int' to the base-class constructor
        # makes undefined items default to int() which is zero.
        # This class plays the role of a metrics tracker.
        # It can record many metrics, including but not limited to loss.
        super(MetricsTracker, self).__init__(int)

    def __add__(self, other: "MetricsTracker") -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v
        for k, v in other.items():
            ans[k] = ans[k] + v
        return ans

    def __mul__(self, alpha: float) -> "MetricsTracker":
        ans = MetricsTracker()
        for k, v in self.items():
            ans[k] = v * alpha
        return ans

    def __str__(self) -> str:
        ans = ""
        for k, v in self.norm_items():
            norm_value = "%.4g" % v
            ans += str(k) + "=" + str(norm_value) + ", "
        samples = "%.2f" % self["samples"]
        ans += "over " + str(samples) + " samples."
        return ans

    def norm_items(self) -> List[Tuple[str, float]]:
        """
        Returns a list of pairs, like:
          [('loss_1', 0.1), ('loss_2', 0.07)]
        """
        samples = self["samples"] if "samples" in self else 1
        ans = []
        for k, v in self.items():
            if k == "samples":
                continue
            norm_value = float(v) / samples
            ans.append((k, norm_value))
        return ans

    def reduce(self, device):
        """
        Sum the metrics over all processes with torch.distributed, so that
        every process ends up with the totals.
        """
        keys = sorted(self.keys())
        s = torch.tensor([float(self[k]) for k in keys], device=device)
        dist.all_reduce(s, op=dist.ReduceOp.SUM)
        for k, v in zip(keys, s.cpu().tolist()):
            self[k] = v

    def write_summary(
        self,
        tb_writer: SummaryWriter,
        prefix: str,
        batch_idx: int,
    ) -> None:
        """Add logging information to a TensorBoard writer.

        Args:
          tb_writer: a TensorBoard writer
          prefix: a prefix for the name of the loss, e.g. "train/valid_",
            or "train/current_"
          batch_idx: The current batch index, used as the x-axis of the plot.
        """
        for k, v in self.norm_items():
            tb_writer.add_scalar(prefix + k, v, batch_idx)

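# Illustrative usage (not part of the original file): accumulate per-batch
# sums and report per-sample averages. The special key "samples" is the
# normalizer used by norm_items().
#
#     tracker = MetricsTracker()
#     tracker["samples"] += 16
#     tracker["loss"] += 8.0
#     print(tracker)  # "loss=0.5, over 16.00 samples."
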
# checkpoint saving and loading
LRSchedulerType = torch.optim.lr_scheduler._LRScheduler

def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer_g: Optional[Optimizer] = None,
    optimizer_d: Optional[Optimizer] = None,
    scheduler_g: Optional[LRSchedulerType] = None,
    scheduler_d: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. We only save its `state_dict()`.
      params:
        User defined parameters, e.g., epoch, loss.
      optimizer_g:
        The optimizer for the generator used in the training.
        Its `state_dict` will be saved.
      optimizer_d:
        The optimizer for the discriminator used in the training.
        Its `state_dict` will be saved.
      scheduler_g:
        The learning rate scheduler for the generator used in the training.
        Its `state_dict` will be saved.
      scheduler_d:
        The learning rate scheduler for the discriminator used in the
        training. Its `state_dict` will be saved.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      sampler:
        The sampler used in the training dataset.
      rank:
        Used in DDP. We save the checkpoint only on the node whose rank is 0.

    Returns:
      Return None.
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    checkpoint = {
        "model": model.state_dict(),
        "optimizer_g": optimizer_g.state_dict() if optimizer_g is not None else None,
        "optimizer_d": optimizer_d.state_dict() if optimizer_d is not None else None,
        "scheduler_g": scheduler_g.state_dict() if scheduler_g is not None else None,
        "scheduler_d": scheduler_d.state_dict() if scheduler_d is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
        "sampler": sampler.state_dict() if sampler is not None else None,
    }

    if params:
        for k, v in params.items():
            assert k not in checkpoint
            checkpoint[k] = v

    torch.save(checkpoint, filename)

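# Illustrative loading counterpart (not part of the original file):
#
#     checkpoint = torch.load(filename, map_location="cpu")
#     model.load_state_dict(checkpoint["model"])
#     if optimizer_g is not None and checkpoint["optimizer_g"] is not None:
#         optimizer_g.load_state_dict(checkpoint["optimizer_g"])
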
def save_checkpoint_with_global_batch_idx(
    out_dir: Path,
    global_batch_idx: int,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer_g: Optional[Optimizer] = None,
    optimizer_d: Optional[Optimizer] = None,
    scheduler_g: Optional[LRSchedulerType] = None,
    scheduler_d: Optional[LRSchedulerType] = None,
    scaler: Optional[GradScaler] = None,
    sampler: Optional[CutSampler] = None,
    rank: int = 0,
):
    """Save training info after processing a given number of batches.

    Args:
      out_dir:
        The directory to save the checkpoint.
      global_batch_idx:
        The number of batches processed so far from the very start of the
        training. The saved checkpoint will have the following filename:
          f'out_dir / checkpoint-{global_batch_idx}.pt'
      model:
        The neural network model whose `state_dict` will be saved in the
        checkpoint.
      params:
        A dict of training configurations to be saved.
      optimizer_g:
        The optimizer for the generator used in the training.
        Its `state_dict` will be saved.
      optimizer_d:
        The optimizer for the discriminator used in the training.
        Its `state_dict` will be saved.
      scheduler_g:
        The learning rate scheduler for the generator used in the training.
        Its `state_dict` will be saved.
      scheduler_d:
        The learning rate scheduler for the discriminator used in the
        training. Its `state_dict` will be saved.
      scaler:
        The scaler used for mixed precision training. Its `state_dict` will
        be saved.
      sampler:
        The sampler used in the training dataset.
      rank:
        The rank ID used in DDP training of the current node. Set it to 0
        if DDP is not used.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
    save_checkpoint(
        filename=filename,
        model=model,
        params=params,
        optimizer_g=optimizer_g,
        optimizer_d=optimizer_d,
        scheduler_g=scheduler_g,
        scheduler_d=scheduler_d,
        scaler=scaler,
        sampler=sampler,
        rank=rank,
    )

# def plot_feature(feature):
#     """
#     Display the feature matrix as an image. Requires matplotlib to be installed.
#     """
#     import matplotlib.pyplot as plt
#
#     feature = np.flip(feature.transpose(1, 0), 0)
#     return plt.matshow(feature)


MATPLOTLIB_FLAG = False

def plot_feature(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring is deprecated; np.frombuffer reads the same RGB bytes.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data
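
# Illustrative usage (not part of the original file): log a spectrogram
# image to TensorBoard. plot_feature returns an (H, W, 3) uint8 array.
#
#     image = plot_feature(mel)  # mel: (num_channels, num_frames) array
#     tb_writer.add_image("train/mel", image, batch_idx, dataformats="HWC")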