Add CTC training.

Fangjun Kuang 2021-07-24 17:13:20 +08:00
parent a01d08f73c
commit f3542c7793
22 changed files with 2196 additions and 8 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
data
__pycache__
path.sh
exp

View File

@ -15,7 +15,7 @@ repos:
rev: 5.9.2
hooks:
- id: isort
args: [--profile=black]
args: [--profile=black, --line-length=80]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1

View File

@ -0,0 +1,82 @@
#!/usr/bin/env python3
"""
This script compiles HLG from
- H, the ctc topology, built from phones contained in data/lang/lexicon.txt
- L, the lexicon, built from data/lang/L_disambig.pt
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
The generated HLG is saved in data/lm/HLG.pt
"""
import k2
import torch
from icefall.lexicon import Lexicon
def main():
lexicon = Lexicon("data/lang")
max_token_id = max(lexicon.tokens)
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load("data/lang/L_disambig.pt"))
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_token_disambig_id = lexicon.phones["#0"]
first_word_disambig_id = lexicon.words["#0"]
L = k2.arc_sort(L)
G = k2.arc_sort(G)
print("Intersecting L and G")
LG = k2.compose(L, G)
print(f"LG shape: {LG.shape}")
print("Connecting LG")
LG = k2.connect(LG)
print(f"LG shape after k2.connect: {LG.shape}")
print(type(LG.aux_labels))
print("Determinizing LG")
LG = k2.determinize(LG)
print(type(LG.aux_labels))
print("Connecting LG after k2.determinize")
LG = k2.connect(LG)
print("Removing disambiguation symbols on LG")
LG.labels[LG.labels >= first_token_disambig_id] = 0
assert isinstance(LG.aux_labels, k2.RaggedInt)
LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0
LG = k2.remove_epsilon(LG)
print(f"LG shape after k2.remove_epsilon: {LG.shape}")
LG = k2.connect(LG)
LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)
print("Arc sorting LG")
LG = k2.arc_sort(LG)
print("Composing H and LG")
HLG = k2.compose(H, LG, inner_labels="phones")
print("Connecting LG")
HLG = k2.connect(HLG)
print("Arc sorting LG")
HLG = k2.arc_sort(HLG)
print("Saving HLG.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG.pt")
if __name__ == "__main__":
main()

View File

@ -231,14 +231,18 @@ def add_self_loops(
arcs:
A list-of-list. The sublist contains
`[src_state, dest_state, label, aux_label, score]`
disambig_phone:
It is the phone ID of the symbol `#0`.
disambig_word:
It is the word ID of the symbol `#0`.
Return:
Return new `arcs` that contain self-loops.
Return new `arcs` containing self-loops.
"""
states_needs_self_loops = set()
for arc in arcs:
src, dst, ilable, olable, score = arc
if olable != 0:
src, dst, ilabel, olabel, score = arc
if olabel != 0:
states_needs_self_loops.add(src)
ans = []
@ -396,11 +400,11 @@ def main():
sil_prob=sil_prob,
need_self_loops=True,
)
# Just for debugging, will remove it
torch.save(L.as_dict(), out_dir / "L.pt")
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
if False:
# Just for debugging, will remove it
torch.save(L.as_dict(), out_dir / "L.pt")
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
L.labels_sym = k2.SymbolTable.from_file(out_dir / "phones.txt")
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")

View File

@ -80,7 +80,18 @@ def test_read_lexicon(filename: str):
fsa_disambig.draw("L_disambig.pdf", title="L_disambig")
if __name__ == "__main__":
def test_lexicon():
from icefall.lexicon import Lexicon
lexicon = Lexicon("data/lang")
print(lexicon.tokens)
def main():
filename = generate_lexicon_file()
test_read_lexicon(filename)
os.remove(filename)
if __name__ == "__main__":
test_lexicon()

View File

@ -87,3 +87,24 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
./local/prepare_lang.py
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
echo "Stage 6: Prepare G"
# We assume you have installed kaldilm; if not, please install
# it using: pip install kaldilm
if [ ! -e data/lm/G_3_gram.fst.txt ]; then
python3 -m kaldilm \
--read-symbol-table="data/lang/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
fi
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
echo "Stage 7: Compile HLG"
if [ ! -f data/lm/HLG.pt ]; then
python3 ./local/compile_hlg.py
fi
fi

View File

@ -0,0 +1,14 @@
## (To be filled in)
It will contain:
- How to run
- WERs
```bash
cd $PWD/..
./prepare.sh
./tdnn_lstm_ctc/train.py
```
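
Once training has produced checkpoints, decoding could be run along the same lines; the `--epoch`/`--avg` values below are simply the defaults from `decode.py` in this commit:

```bash
./tdnn_lstm_ctc/decode.py --epoch 9 --avg 5
```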

View File

@ -0,0 +1,210 @@
#!/usr/bin/env python3
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
import torch.nn as nn
from model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
write_error_stats,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=9,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp3/"),
"lang_dir": Path("data/lang"),
"feature_dim": 80,
"subsampling_factor": 3,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
}
)
return params
@torch.no_grad()
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
) -> List[Tuple[List[str], List[str]]]:
"""Decode one batch and return a list of tuples containing
`(ref_words, hyp_words)`.
Args:
params:
It is the return value of :func:`get_params`.
"""
device = HLG.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is [N, T, C]
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
nnet_output = model(feature)
# nnet_output is [N, T, C]
supervisions = batch["supervisions"]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
lattices = k2.intersect_dense_pruned(
HLG,
dense_fsa_vec,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_paths = k2.shortest_path(lattices, use_double_scores=True)
hyps = get_texts(best_paths)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
texts = supervisions["text"]
results = []
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
results.append((ref_words, hyp_words))
return results
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_phone_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
model = TdnnLstm(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
subsampling_factor=params.subsampling_factor,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
tot_num_cuts = len(test_dl.dataset.cuts)
num_cuts = 0
results = []
for batch_idx, batch in enumerate(test_dl):
this_batch = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
)
results.extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
logging.info(
f"batch {batch_idx}, cuts processed until now is "
f"{num_cuts}/{tot_num_cuts} "
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
)
errs_filename = params.exp_dir / f"errs-{test_set}.txt"
with open(errs_filename, "w") as f:
write_error_stats(f, test_set, results)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,86 @@
import torch
import torch.nn as nn
class TdnnLstm(nn.Module):
def __init__(
self, num_features: int, num_classes: int, subsampling_factor: int = 3
) -> None:
"""
Args:
num_features:
The input dimension of the model.
num_classes:
The output dimension of the model.
subsampling_factor:
It reduces the number of output frames by this factor.
"""
super().__init__()
self.num_features = num_features
self.num_classes = num_classes
self.subsampling_factor = subsampling_factor
self.tdnn = nn.Sequential(
nn.Conv1d(
in_channels=num_features,
out_channels=500,
kernel_size=3,
stride=1,
padding=1,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=500, affine=False),
nn.Conv1d(
in_channels=500,
out_channels=500,
kernel_size=3,
stride=1,
padding=1,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=500, affine=False),
nn.Conv1d(
in_channels=500,
out_channels=500,
kernel_size=3,
stride=self.subsampling_factor, # stride: subsampling_factor!
padding=1,
),
nn.ReLU(inplace=True),
nn.BatchNorm1d(num_features=500, affine=False),
)
self.lstms = nn.ModuleList(
[
nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
for _ in range(5)
]
)
self.lstm_bnorms = nn.ModuleList(
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
)
self.dropout = nn.Dropout(0.2)
self.linear = nn.Linear(in_features=500, out_features=self.num_classes)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
Its shape is [N, C, T]
Returns:
The output tensor has shape [N, T, C]
"""
x = self.tdnn(x)
x = x.permute(2, 0, 1) # (N, C, T) -> (T, N, C) -> how LSTM expects it
for lstm, bnorm in zip(self.lstms, self.lstm_bnorms):
x_new, _ = lstm(x)
x_new = bnorm(x_new.permute(1, 2, 0)).permute(
2, 0, 1
) # (T, N, C) -> (N, C, T) -> (T, N, C)
x_new = self.dropout(x_new)
x = x_new + x # skip connections
x = x.transpose(
1, 0
) # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim
x = self.linear(x)
x = nn.functional.log_softmax(x, dim=-1)
return x
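
The forward pass above takes `[N, C, T]` input and returns `[N, T', num_classes]` log-probabilities, with `T'` reduced by `subsampling_factor`. A minimal shape check, assuming `model.py` is importable and using arbitrary example sizes, might look like this:

```python
import torch

from model import TdnnLstm  # the class defined above

model = TdnnLstm(num_features=80, num_classes=72, subsampling_factor=3)
model.eval()  # use running BatchNorm stats and disable dropout

x = torch.randn(4, 80, 300)  # [N, C, T]
with torch.no_grad():
    y = model(x)             # log-probabilities, [N, T', num_classes]
print(y.shape)               # torch.Size([4, 100, 72]), i.e. T' == T // 3
```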

View File

@ -0,0 +1,493 @@
#!/usr/bin/env python3
# This is just at the very beginning ...
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional
import k2
import torch
import torch.nn as nn
import torch.optim as optim
from model import TdnnLstm
from torch.nn.utils import clip_grad_value_
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, encode_supervisions, setup_logger
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
# TODO: add extra arguments and support DDP training.
# Currently, only single GPU training is implemented. Will add
# DDP training once single GPU training is finished.
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training-related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- exp_dir: It specifies the directory where all training related
files, e.g., checkpoints, logs, etc., are saved
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
and continue training from that checkpoint.
- num_epochs: Number of epochs to train.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used for writing statistics to tensorboard. It
contains the number of batches trained so far across
epochs.
- log_interval: Print training loss if `batch_idx % log_interval` is 0
- valid_interval: Run validation if `batch_idx % valid_interval` is 0
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
"""
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp"),
"lang_dir": Path("data/lang"),
"lr": 1e-3,
"feature_dim": 80,
"weight_decay": 5e-4,
"subsampling_factor": 3,
"start_epoch": 0,
"num_epochs": 10,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"valid_interval": 1000,
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: CtcTrainingGraphCompiler,
is_training: bool,
):
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of TdnnLstm in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build a decoding graph from a ctc topo and training
transcript. The training transcript is contained in the given `batch`,
while the ctc topo is built when this compiler is instantiated.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is [N, T, C]
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
assert feature.ndim == 3
feature = feature.to(device)
with torch.set_grad_enabled(is_training):
nnet_output = model(feature)
# nnet_output is [N, T, C]
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervisions = batch["supervisions"]
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
decoding_graph = graph_compiler.compile(texts)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
assert loss.requires_grad == is_training
# train_frames and valid_frames are used for printing.
if is_training:
params.train_frames = supervision_segments[:, 2].sum().item()
else:
params.valid_frames = supervision_segments[:, 2].sum().item()
return loss
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: CtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
) -> None:
"""Run the validation process. The validation loss
is saved in `params.valid_loss`.
"""
model.eval()
tot_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu
tot_frames += params.valid_frames
params.valid_loss = tot_loss / tot_frames
if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = params.valid_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: CtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
) -> None:
"""Train the model for one epoch.
The average training loss over all frames is saved in
`params.train_loss`. The validation process is run every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
"""
model.train()
tot_loss = 0.0 # sum of losses over all batches
tot_frames = 0.0 # sum of frames over all batches
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# NOTE: We use reduction='sum', so the loss is summed over all utterances
# in the batch; there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_value_(model.parameters(), 5.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()
tot_frames += params.train_frames
tot_loss += loss_cpu
tot_avg_loss = tot_loss / tot_frames
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss}, "
f"best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
params.train_loss = tot_loss / tot_frames
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
lexicon = Lexicon(params.lang_dir)
max_phone_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
model = TdnnLstm(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
subsampling_factor=params.subsampling_factor,
)
model.to(device)
optimizer = optim.AdamW(
model.parameters(),
lr=params.lr,
weight_decay=params.weight_decay,
)
scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
load_checkpoint_if_available(
params=params, model=model, optimizer=optimizer
)
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
if epoch > params.start_epoch:
logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}")
if tb_writer is not None:
tb_writer.add_scalar(
"train/lr",
scheduler.get_last_lr()[0],
params.batch_idx_train,
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
)
scheduler.step()
save_checkpoint(
params=params, model=model, optimizer=optimizer, scheduler=scheduler
)
logging.info("Done!")
if __name__ == "__main__":
main()

131
icefall/checkpoint.py Normal file
View File

@ -0,0 +1,131 @@
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
def save_checkpoint(
filename: Path,
model: Union[nn.Module, DDP],
params: Optional[Dict[str, Any]] = None,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
scaler: Optional[GradScaler] = None,
rank: int = 0,
) -> None:
"""Save training information to a file.
Args:
filename:
The checkpoint filename.
model:
The model to be saved. We only save its `state_dict()`.
params:
User defined parameters, e.g., epoch, loss.
optimizer:
The optimizer to be saved. We only save its `state_dict()`.
scheduler:
The scheduler to be saved. We only save its `state_dict()`.
scaler:
The GradScaler to be saved. We only save its `state_dict()`.
rank:
Used in DDP. We save checkpoint only for the node whose rank is 0.
Returns:
Return None.
"""
if rank != 0:
return
logging.info(f"Saving checkpoint to {filename}")
if isinstance(model, DDP):
model = model.module
checkpoint = {
"model": model.state_dict(),
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"scheduler": scheduler.state_dict() if scheduler is not None else None,
"grad_scaler": scaler.state_dict() if scaler is not None else None,
}
if params:
for k, v in params.items():
assert k not in checkpoint
checkpoint[k] = v
torch.save(checkpoint, filename)
def load_checkpoint(
filename: Path,
model: nn.Module,
optimizer: Optional[Optimizer] = None,
scheduler: Optional[_LRScheduler] = None,
scaler: Optional[GradScaler] = None,
) -> Dict[str, Any]:
"""
TODO: document it
"""
logging.info(f"Loading checkpoint from {filename}")
checkpoint = torch.load(filename, map_location="cpu")
if next(iter(checkpoint["model"])).startswith("module."):
logging.info("Loading checkpoint saved by DDP")
dst_state_dict = model.state_dict()
src_state_dict = checkpoint["model"]
for key in dst_state_dict.keys():
src_key = "{}.{}".format("module", key)
dst_state_dict[key] = src_state_dict.pop(src_key)
assert len(src_state_dict) == 0
model.load_state_dict(dst_state_dict, strict=False)
else:
model.load_state_dict(checkpoint["model"], strict=False)
checkpoint.pop("model")
def load(name, obj):
s = checkpoint[name]
if obj and s:
obj.load_state_dict(s)
checkpoint.pop(name)
load("optimizer", optimizer)
load("scheduler", scheduler)
load("grad_scaler", scaler)
return checkpoint
def average_checkpoints(filenames: List[Path]) -> dict:
"""Average a list of checkpoints.
Args:
filenames:
Filenames of the checkpoints to be averaged. We assume all
checkpoints are saved by :func:`save_checkpoint`.
Returns:
Return a dict (i.e., state_dict) which is the average of all
model state dicts contained in the checkpoints.
"""
n = len(filenames)
avg = torch.load(filenames[0], map_location="cpu")["model"]
for i in range(1, n):
state_dict = torch.load(filenames[i], map_location="cpu")["model"]
for k in avg:
avg[k] += state_dict[k]
for k in avg:
if avg[k].is_floating_point():
avg[k] /= n
else:
avg[k] //= n
return avg

View File

View File

@ -0,0 +1,248 @@
import argparse
import logging
from pathlib import Path
from typing import List, Union
from lhotse import Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
SingleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class AsrDataModule(DataModule):
"""
DataModule for K2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--feature-dir",
type=Path,
default=Path("data/fbank"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=500.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=False,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the BucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=True,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
cuts_train = self.train_cuts()
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
logging.info("About to create train dataset")
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
# Cut concatenation should be the first transform in the list,
# so that if we e.g. mix noise in, it will fill the gaps between
# different utterances.
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
input_transforms = [
SpecAugment(
num_frame_masks=2,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
]
train = K2SpeechRecognitionDataset(
cuts_train,
cut_transforms=transforms,
input_transforms=input_transforms,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
cuts_train = cuts_train.drop_features()
train = K2SpeechRecognitionDataset(
cuts=cuts_train,
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
)
if self.args.bucketing_sampler:
logging.info("Using BucketingSampler.")
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=True,
num_buckets=self.args.num_buckets,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=True,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=4,
persistent_workers=True,
)
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
cuts_valid = cuts_valid.drop_features()
validate = K2SpeechRecognitionDataset(
cuts_valid,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
)
else:
validate = K2SpeechRecognitionDataset(cuts_valid)
valid_sampler = SingleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=True,
)
return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
cuts = self.test_cuts()
is_list = isinstance(cuts, list)
test_loaders = []
if not is_list:
cuts = [cuts]
for cuts_test in cuts:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
cuts_test,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test, batch_size=None, sampler=sampler, num_workers=1
)
test_loaders.append(test_dl)
if is_list:
return test_loaders
else:
return test_loaders[0]

View File

@ -0,0 +1,43 @@
import argparse
from typing import List, Union
from lhotse import CutSet
from torch.utils.data import DataLoader
class DataModule:
"""
Contains dataset-related code. It is intended to read/construct Lhotse cuts,
and create Dataset/Sampler/DataLoader out of them.
There is a separate method to create each of train/valid/test DataLoader.
In principle, there might be multiple DataLoaders for each of
train/valid/test
(e.g. when a corpus has multiple test sets).
The API of this class allows returning lists of CutSets/DataLoaders.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
pass
def train_cuts(self) -> Union[CutSet, List[CutSet]]:
raise NotImplementedError()
def valid_cuts(self) -> Union[CutSet, List[CutSet]]:
raise NotImplementedError()
def test_cuts(self) -> Union[CutSet, List[CutSet]]:
raise NotImplementedError()
def train_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
raise NotImplementedError()
def valid_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
raise NotImplementedError()
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
raise NotImplementedError()

View File

@ -0,0 +1,68 @@
import argparse
import logging
from functools import lru_cache
from typing import List
from lhotse import CutSet, load_manifest
from icefall.dataset.asr_datamodule import AsrDataModule
from icefall.utils import str2bool
class LibriSpeechAsrDataModule(AsrDataModule):
"""
LibriSpeech ASR data module. Can be used for 100h subset
(``--full-libri false``) or full 960h set.
The train and valid cuts for standard Libri splits are
concatenated into a single CutSet/DataLoader.
"""
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(title="LibriSpeech specific options")
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech.",
)
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz"
)
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
return cuts_valid
@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test-clean", "test-other"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
)
)
return cuts

109
icefall/graph_compiler.py Normal file
View File

@ -0,0 +1,109 @@
from typing import List
import k2
import torch
from icefall.lexicon import Lexicon
class CtcTrainingGraphCompiler(object):
def __init__(
self,
lexicon: Lexicon,
device: torch.device,
oov: str = "<UNK>",
):
"""
Args:
lexicon:
It is built from `data/lang/lexicon.txt`.
device:
The device to use for operations compiling transcripts to FSAs.
oov:
Out of vocabulary word. When a word in the transcript
does not exist in the lexicon, it is replaced with `oov`.
"""
L_inv = lexicon.L_inv.to(device)
assert L_inv.requires_grad is False
assert oov in lexicon.words
self.L_inv = k2.arc_sort(L_inv)
self.oov_id = lexicon.words[oov]
self.words = lexicon.words
max_token_id = max(lexicon.tokens)
ctc_topo = k2.ctc_topo(max_token_id, modified=False)
self.ctc_topo = ctc_topo.to(device)
self.device = device
def compile(self, texts: List[str]) -> k2.Fsa:
"""Build decoding graphs by composing ctc_topo with
given transcripts.
Args:
texts:
A list of strings. Each string contains a sentence for an utterance.
A sentence consists of space-separated words. An example `texts`
looks like:
['hello icefall', 'CTC training with k2']
Returns:
An FsaVec, the composition result of `self.ctc_topo` and the
transcript FSA.
"""
transcript_fsa = self.convert_transcript_to_fsa(texts)
# NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
# is False, so we remove epsilons and add self-loops here
fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
transcript_fsa
)
fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
decoding_graph = k2.compose(
self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
)
assert decoding_graph.requires_grad is False
return decoding_graph
def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
"""Convert a list of transcript texts to an FsaVec.
Args:
texts:
A list of strings. Each string contains a sentence for an utterance.
A sentence consists of space-separated words. An example `texts`
looks like:
['hello icefall', 'CTC training with k2']
Returns:
Return an FsaVec whose `shape[0]` equals `len(texts)`.
"""
word_ids_list = []
for text in texts:
word_ids = []
for word in text.split(" "):
if word in self.words:
word_ids.append(self.words[word])
else:
word_ids.append(self.oov_id)
word_ids_list.append(word_ids)
word_fsa = k2.linear_fsa(word_ids_list, self.device)
word_fsa_with_self_loops = k2.add_epsilon_self_loops(word_fsa)
fsa = k2.intersect(
self.L_inv, word_fsa_with_self_loops, treat_epsilons_specially=False
)
# fsa has word ID as labels and token ID as aux_labels, so
# we need to invert it
ans_fsa = fsa.invert_()
return k2.arc_sort(ans_fsa)

66
icefall/lexicon.py Normal file
View File

@ -0,0 +1,66 @@
import logging
import re
from pathlib import Path
from typing import List
import k2
import torch
class Lexicon(object):
"""Phone based lexicon.
TODO: Add BpeLexicon for BPE models.
"""
def __init__(
self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$")
):
"""
Args:
lang_dir:
Path to the lang directory. It is expected to contain the following
files:
- phones.txt
- words.txt
- L.pt
The above files are produced by the script `prepare.sh`. You
should have run that before running the training code.
disambig_pattern:
It contains the pattern for disambiguation symbols.
"""
lang_dir = Path(lang_dir)
self.phones = k2.SymbolTable.from_file(lang_dir / "phones.txt")
self.words = k2.SymbolTable.from_file(lang_dir / "words.txt")
if (lang_dir / "Linv.pt").exists():
logging.info("Loading pre-compiled Linv.pt")
L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt"))
else:
logging.info("Converting L.pt to Linv.pt")
L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt"))
L_inv = k2.arc_sort(L.invert())
torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
# We save L_inv instead of L because it will be used to intersect with the
# transcript FSAs; both have word IDs as labels.
self.L_inv = L_inv
self.disambig_pattern = disambig_pattern
@property
def tokens(self) -> List[int]:
"""Return a list of phone IDs excluding those from
disambiguation symbols.
Caution:
0 is not a phone ID so it is excluded from the return value.
"""
symbols = self.phones.symbols
ans = []
for s in symbols:
if not self.disambig_pattern.match(s):
ans.append(self.phones[s])
if 0 in ans:
ans.remove(0)
ans.sort()
return ans

View File

@ -1,5 +1,20 @@
import argparse
import logging
import os
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, List, TextIO, Tuple, Union
import k2
import k2.ragged as k2r
import kaldialign
import torch
import torch.distributed as dist
Pathlike = Union[str, Path]
@contextmanager
@ -32,3 +47,286 @@ def get_executor():
# No need to return anything - compute_and_store_features
# will just instantiate the pool itself.
yield None
def str2bool(v):
"""Used in argparse.ArgumentParser.add_argument to indicate
that a type is a bool type and user can enter
- yes, true, t, y, 1, to represent True
- no, false, f, n, 0, to represent False
See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
"""
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
def setup_logger(
log_filename: Pathlike, log_level: str = "info", use_console: bool = True
) -> None:
"""Setup log level.
Args:
log_filename:
The filename to save the log.
log_level:
The log level to use, e.g., "debug", "info", "warning", "error",
"critical"
"""
now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
if dist.is_available() and dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
log_filename = f"{log_filename}-{date_time}-{rank}"
else:
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
log_filename = f"{log_filename}-{date_time}"
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
level = logging.ERROR
if log_level == "debug":
level = logging.DEBUG
elif log_level == "info":
level = logging.INFO
elif log_level == "warning":
level = logging.WARNING
elif log_level == "critical":
level = logging.CRITICAL
logging.basicConfig(
filename=log_filename, format=formatter, level=level, filemode="w"
)
if use_console:
console = logging.StreamHandler()
console.setLevel(level)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger("").addHandler(console)
def get_env_info():
"""
TODO:
"""
return {
"k2-git-sha1": None,
"k2-version": None,
"lhotse-version": None,
"torch-version": None,
"icefall-sha1": None,
"icefall-version": None,
}
# See
# https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute # noqa
class AttributeDict(dict):
__slots__ = ()
__getattr__ = dict.__getitem__
__setattr__ = dict.__setitem__
def encode_supervisions(
supervisions: Dict[str, torch.Tensor], subsampling_factor: int
) -> Tuple[torch.Tensor, List[str]]:
"""
Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor,
and a list of transcription strings.
The supervision tensor has shape ``(batch_size, 3)``.
Its second dimension contains information about sequence index [0],
start frames [1] and num frames [2].
The batch items might become re-ordered during this operation -- the
returned tensor and list of strings are guaranteed to be consistent with
each other.
"""
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // subsampling_factor,
supervisions["num_frames"] // subsampling_factor,
),
1,
).to(torch.int32)
indices = torch.argsort(supervision_segments[:, 2], descending=True)
supervision_segments = supervision_segments[indices]
texts = supervisions["text"]
texts = [texts[idx] for idx in indices]
return supervision_segments, texts
def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
"""Extract the texts from the best-path FSAs.
Args:
best_paths:
A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
containing multiple FSAs, which is expected to be the result
of k2.shortest_path (otherwise the returned values won't
be meaningful).
Returns:
Returns a list of lists of int, containing the label sequences we
decoded.
"""
if isinstance(best_paths.aux_labels, k2.RaggedInt):
# remove 0's and -1's.
aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0)
aux_shape = k2r.compose_ragged_shapes(
best_paths.arcs.shape(), aux_labels.shape()
)
# remove the states and arcs axes.
aux_shape = k2r.remove_axis(aux_shape, 1)
aux_shape = k2r.remove_axis(aux_shape, 1)
aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())
else:
# remove axis corresponding to states.
aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels)
# remove 0's and -1's.
aux_labels = k2r.remove_values_leq(aux_labels, 0)
assert aux_labels.num_axes() == 2
return k2r.to_list(aux_labels)
def write_error_stats(
f: TextIO, test_set_name: str, results: List[Tuple[List[str], List[str]]]
) -> float:
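"""Write statistics about the decoding results to `f` and return the WER.
For every (ref_words, hyp_words) pair the two word sequences are aligned
with kaldialign; substitution, insertion and deletion counts are
accumulated, per-utterance and per-word details are written to `f`, and
the overall WER is both logged and returned.
Args:
f:
The file object to which the detailed statistics are written.
test_set_name:
The name of the test set; it is used only in the log message.
results:
A list of tuples `(ref_words, hyp_words)`, where each element is a
list of words from the reference and the hypothesis, respectively.
Returns:
Return the word error rate in percent as a float.
"""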
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
for ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
ref_len = sum([len(r) for r, _ in results])
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
for ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR)
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
" ".join(
(
ref_word
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted(
[(v, k) for k, v in subs.items()], reverse=True
):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print(
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate)

50
test/test_checkpoint.py Normal file
View File

@ -0,0 +1,50 @@
#!/usr/bin/env python3
import pytest
import torch
import torch.nn as nn
from icefall.checkpoint import (
average_checkpoints,
load_checkpoint,
save_checkpoint,
)
@pytest.fixture
def checkpoints1(tmp_path):
f = tmp_path / "f.pt"
m = nn.Module()
m.p1 = nn.Parameter(torch.tensor([10.0, 20.0]), requires_grad=False)
m.register_buffer("p2", torch.tensor([10, 100]))
params = {"a": 10, "b": 20}
save_checkpoint(f, m, params=params)
return f
@pytest.fixture
def checkpoints2(tmp_path):
f = tmp_path / "f2.pt"
m = nn.Module()
m.p1 = nn.Parameter(torch.Tensor([50, 30.0]))
m.register_buffer("p2", torch.tensor([1, 3]))
params = {"a": 100, "b": 200}
save_checkpoint(f, m, params=params)
return f
def test_load_checkpoints(checkpoints1):
m = nn.Module()
m.p1 = nn.Parameter(torch.Tensor([0, 0.0]))
m.p2 = nn.Parameter(torch.Tensor([0, 0]))
params = load_checkpoint(checkpoints1, m)
assert torch.allclose(m.p1, torch.Tensor([10.0, 20]))
assert params == {"a": 10, "b": 20}
def test_average_checkpoints(checkpoints1, checkpoints2):
state_dict = average_checkpoints([checkpoints1, checkpoints2])
assert torch.allclose(state_dict["p1"], torch.Tensor([30, 25.0]))
assert torch.allclose(state_dict["p2"], torch.tensor([5, 51]))

160
test/test_graph_compiler.py Normal file
View File

@ -0,0 +1,160 @@
#!/usr/bin/env python3
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
import re
import k2
import pytest
import torch
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import get_texts
@pytest.fixture
def lexicon():
"""
We use the following test data:
lexicon.txt
foo f o o
bar b a r
baz b a z
<UNK> SPN
phones.txt
<eps> 0
a 1
b 2
f 3
o 4
r 5
z 6
SPN 7
words.txt:
<eps> 0
foo 1
bar 2
baz 3
<UNK> 4
"""
L = k2.Fsa.from_str(
"""
0 0 7 4 0
0 7 -1 -1 0
0 1 3 1 0
0 3 2 2 0
0 5 2 3 0
1 2 4 0 0
2 0 4 0 0
3 4 1 0 0
4 0 5 0 0
5 6 1 0 0
6 0 6 0 0
7
""",
num_aux_labels=1,
)
L.labels_sym = k2.SymbolTable.from_str(
"""
a 1
b 2
f 3
o 4
r 5
z 6
SPN 7
"""
)
L.aux_labels_sym = k2.SymbolTable.from_str(
"""
foo 1
bar 2
baz 3
<UNK> 4
"""
)
ans = Lexicon.__new__(Lexicon)
ans.phones = L.labels_sym
ans.words = L.aux_labels_sym
ans.L_inv = k2.arc_sort(L.invert_())
ans.disambig_pattern = re.compile(r"^#\d+$")
return ans
@pytest.fixture
def compiler(lexicon):
return CtcTrainingGraphCompiler(lexicon, device=torch.device("cpu"))
class TestCtcTrainingGraphCompiler(object):
@staticmethod
def test_convert_transcript_to_fsa(compiler, lexicon):
texts = ["bar foo", "baz ok"]
fsa = compiler.convert_transcript_to_fsa(texts)
labels0 = fsa[0].labels[:-1].tolist()
aux_labels0 = fsa[0].aux_labels[:-1]
aux_labels0 = aux_labels0[aux_labels0 != 0].tolist()
labels1 = fsa[1].labels[:-1].tolist()
aux_labels1 = fsa[1].aux_labels[:-1]
aux_labels1 = aux_labels1[aux_labels1 != 0].tolist()
labels0 = [lexicon.phones[i] for i in labels0]
labels1 = [lexicon.phones[i] for i in labels1]
aux_labels0 = [lexicon.words[i] for i in aux_labels0]
aux_labels1 = [lexicon.words[i] for i in aux_labels1]
assert labels0 == ["b", "a", "r", "f", "o", "o"]
assert aux_labels0 == ["bar", "foo"]
assert labels1 == ["b", "a", "z", "SPN"]
assert aux_labels1 == ["baz", "<UNK>"]
@staticmethod
def test_compile(compiler, lexicon):
texts = ["bar foo", "baz ok"]
decoding_graph = compiler.compile(texts)
input1 = ["b", "b", "<blk>", "<blk>", "a", "a", "r", "<blk>", "<blk>"]
input1 += ["f", "f", "<blk>", "<blk>", "o", "o", "<blk>", "o", "o"]
input2 = ["b", "b", "a", "a", "a", "<blk>", "<blk>", "z", "z"]
input2 += ["<blk>", "<blk>", "SPN", "SPN", "<blk>", "<blk>"]
lexicon.phones._id2sym[0] == "<blk>"
lexicon.phones._sym2id["<blk>"] = 0
input1 = [lexicon.phones[i] for i in input1]
input2 = [lexicon.phones[i] for i in input2]
fsa1 = k2.linear_fsa(input1)
fsa2 = k2.linear_fsa(input2)
fsas = k2.Fsa.from_fsas([fsa1, fsa2])
decoding_graph = k2.arc_sort(decoding_graph)
lattice = k2.intersect(
decoding_graph, fsas, treat_epsilons_specially=False
)
lattice = k2.connect(lattice)
aux_labels0 = lattice[0].aux_labels[:-1]
aux_labels0 = aux_labels0[aux_labels0 != 0].tolist()
aux_labels0 = [lexicon.words[i] for i in aux_labels0]
assert aux_labels0 == ["bar", "foo"]
aux_labels1 = lattice[1].aux_labels[:-1]
aux_labels1 = aux_labels1[aux_labels1 != 0].tolist()
aux_labels1 = [lexicon.words[i] for i in aux_labels1]
assert aux_labels1 == ["baz", "<UNK>"]
texts = get_texts(lattice)
texts = [[lexicon.words[i] for i in words] for words in texts]
assert texts == [["bar", "foo"], ["baz", "<UNK>"]]

93
test/test_utils.py Normal file
View File

@ -0,0 +1,93 @@
#!/usr/bin/env python3
import k2
import pytest
import torch
from icefall.utils import AttributeDict, encode_supervisions, get_texts
@pytest.fixture
def sup():
sequence_idx = torch.tensor([0, 1, 2])
start_frame = torch.tensor([1, 3, 9])
num_frames = torch.tensor([20, 30, 10])
text = ["one", "two", "three"]
return {
"sequence_idx": sequence_idx,
"start_frame": start_frame,
"num_frames": num_frames,
"text": text,
}
def test_encode_supervisions(sup):
supervision_segments, texts = encode_supervisions(sup, subsampling_factor=4)
assert torch.all(
torch.eq(
supervision_segments,
torch.tensor(
[[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
),
)
)
assert texts == ["two", "one", "three"]
def test_get_texts_ragged():
fsa1 = k2.Fsa.from_str(
"""
0 1 1 10
1 2 2 20
2 3 3 30
3 4 -1 0
4
"""
)
fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]")
fsa2 = k2.Fsa.from_str(
"""
0 1 1 1
1 2 2 2
2 3 -1 0
3
"""
)
fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]")
fsas = k2.Fsa.from_fsas([fsa1, fsa2])
texts = get_texts(fsas)
assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]
def test_get_texts_regular():
fsa1 = k2.Fsa.from_str(
"""
0 1 1 3 10
1 2 2 0 20
2 3 3 2 30
3 4 -1 -1 0
4
""",
num_aux_labels=1,
)
fsa2 = k2.Fsa.from_str(
"""
0 1 1 10 1
1 2 2 5 2
2 3 -1 -1 0
3
""",
num_aux_labels=1,
)
fsas = k2.Fsa.from_fsas([fsa1, fsa2])
texts = get_texts(fsas)
assert texts == [[3, 2], [10, 5]]
def test_attribute_dict():
s = AttributeDict({"a": 10, "b": 20})
assert s.a == 10
assert s["b"] == 20
s.c = 100
assert s["c"] == 100