Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)

Commit: Add CTC training.

This commit is contained in:
parent a01d08f73c
commit f3542c7793
.gitignore (vendored) | 1 line changed

@@ -1,3 +1,4 @@
 data
 __pycache__
 path.sh
+exp
@@ -15,7 +15,7 @@ repos:
     rev: 5.9.2
     hooks:
       - id: isort
-        args: [--profile=black]
+        args: [--profile=black, --line-length=80]

   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.0.1
egs/librispeech/ASR/local/compile_hlg.py (new file, 82 lines)
@@ -0,0 +1,82 @@
#!/usr/bin/env python3

"""
This script compiles HLG from

    - H, the ctc topology, built from phones contained in data/lang/lexicon.txt
    - L, the lexicon, built from data/lang/L_disambig.pt

      Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_3_gram.fst.txt

The generated HLG is saved in data/lm/HLG.pt
"""
import k2
import torch

from icefall.lexicon import Lexicon


def main():
    lexicon = Lexicon("data/lang")
    max_token_id = max(lexicon.tokens)
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load("data/lang/L_disambig.pt"))
    with open("data/lm/G_3_gram.fst.txt") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)

    first_token_disambig_id = lexicon.phones["#0"]
    first_word_disambig_id = lexicon.words["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    print("Intersecting L and G")
    LG = k2.compose(L, G)
    print(f"LG shape: {LG.shape}")

    print("Connecting LG")
    LG = k2.connect(LG)
    print(f"LG shape after k2.connect: {LG.shape}")

    print(type(LG.aux_labels))
    print("Determinizing LG")

    LG = k2.determinize(LG)
    print(type(LG.aux_labels))

    print("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    print("Removing disambiguation symbols on LG")

    LG.labels[LG.labels >= first_token_disambig_id] = 0

    assert isinstance(LG.aux_labels, k2.RaggedInt)
    LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    print(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    print("Arc sorting LG")
    LG = k2.arc_sort(LG)

    print("Composing H and LG")
    HLG = k2.compose(H, LG, inner_labels="phones")

    print("Connecting HLG")
    HLG = k2.connect(HLG)

    print("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)

    print("Saving HLG.pt to data/lm")
    torch.save(HLG.as_dict(), "data/lm/HLG.pt")


if __name__ == "__main__":
    main()
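For reference, the graph written by this script is loaded back elsewhere in this commit (see decode.py below) with `k2.Fsa.from_dict`. A minimal sketch, assuming `data/lm/HLG.pt` has already been generated:

```python
import k2
import torch

# Load the serialized decoding graph saved by compile_hlg.py.
HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))

# Move it to a GPU before decoding, as decode.py does.
if torch.cuda.is_available():
    HLG = HLG.to(torch.device("cuda", 0))

print(HLG.shape)  # (num_states, None) for a single FSA
```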
@@ -231,14 +231,18 @@ def add_self_loops(
       arcs:
         A list-of-list. The sublist contains
         `[src_state, dest_state, label, aux_label, score]`
       disambig_phone:
         It is the phone ID of the symbol `#0`.
       disambig_word:
         It is the word ID of the symbol `#0`.

     Return:
-      Return new `arcs` that contain self-loops.
+      Return new `arcs` containing self-loops.
     """
     states_needs_self_loops = set()
     for arc in arcs:
-        src, dst, ilable, olable, score = arc
-        if olable != 0:
+        src, dst, ilabel, olabel, score = arc
+        if olabel != 0:
             states_needs_self_loops.add(src)

     ans = []
@@ -396,11 +400,11 @@ def main():
        sil_prob=sil_prob,
        need_self_loops=True,
    )
    # Just for debugging, will remove it
    torch.save(L.as_dict(), out_dir / "L.pt")
    torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")

    if False:
        # Just for debugging, will remove it
        torch.save(L.as_dict(), out_dir / "L.pt")
        torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")

    L.labels_sym = k2.SymbolTable.from_file(out_dir / "phones.txt")
    L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
@@ -80,7 +80,18 @@ def test_read_lexicon(filename: str):
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")


-if __name__ == "__main__":
+def test_lexicon():
+    from icefall.lexicon import Lexicon
+
+    lexicon = Lexicon("data/lang")
+    print(lexicon.tokens)
+
+
+def main():
     filename = generate_lexicon_file()
     test_read_lexicon(filename)
     os.remove(filename)
+
+
+if __name__ == "__main__":
+    test_lexicon()
@@ -87,3 +87,24 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then

   ./local/prepare_lang.py
 fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  echo "Stage 6: Prepare G"
+  # We assume you have installed kaldilm. If not, please install
+  # it using: pip install kaldilm
+
+  if [ ! -e data/lm/G_3_gram.fst.txt ]; then
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  echo "Stage 7: Compile HLG"
+  if [ ! -f data/lm/HLG.pt ]; then
+    python3 ./local/compile_hlg.py
+  fi
+fi
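Stage 6 writes an OpenFst text-format transducer that Stage 7 (compile_hlg.py) then consumes. As a quick sanity check, not part of the recipe, the file can be loaded directly with k2, mirroring what compile_hlg.py does:

```python
import k2

# G_3_gram.fst.txt is the text-format FST produced by kaldilm in Stage 6.
with open("data/lm/G_3_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)

print(G.shape)
```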
egs/librispeech/ASR/tdnn_lstm_ctc/README.md (new file, 14 lines)
@@ -0,0 +1,14 @@
## (To be filled in)

It will contain:

- How to run
- WERs

```bash
cd $PWD/..

./prepare.sh

./tdnn_lstm_ctc/train.py
```
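Once training has produced checkpoints, decoding is done with the decode.py script added in this commit. A sketch of the command, using the script's default checkpoint selection:

```bash
./tdnn_lstm_ctc/decode.py --epoch 9 --avg 5
```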
egs/librispeech/ASR/tdnn_lstm_ctc/__init__.py (new, empty file)

egs/librispeech/ASR/tdnn_lstm_ctc/decode.py (new executable file, 210 lines)
@@ -0,0 +1,210 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model import TdnnLstm
|
||||
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=9,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"exp_dir": Path("tdnn_lstm_ctc/exp3/"),
|
||||
"lang_dir": Path("data/lang"),
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 3,
|
||||
"search_beam": 20,
|
||||
"output_beam": 8,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: k2.Fsa,
|
||||
batch: dict,
|
||||
lexicon: Lexicon,
|
||||
) -> List[Tuple[List[str], List[str]]]:
|
||||
"""Decode one batch and return a list of tuples containing
|
||||
`(ref_words, hyp_words)`.
|
||||
|
||||
Args:
|
||||
params:
  It is the return value of :func:`get_params`.
model:
  The neural network model used for decoding.
HLG:
  The decoding graph compiled by ./local/compile_hlg.py.
batch:
  A batch of data produced by the test dataloader.
lexicon:
  It is used to map word IDs in the lattice back to words.
"""
|
||||
device = HLG.device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is [N, T, C]
|
||||
|
||||
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
|
||||
|
||||
nnet_output = model(feature)
|
||||
# nnet_output is [N, T, C]
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"] // params.subsampling_factor,
|
||||
supervisions["num_frames"] // params.subsampling_factor,
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
|
||||
|
||||
lattices = k2.intersect_dense_pruned(
|
||||
HLG,
|
||||
dense_fsa_vec,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
)
|
||||
|
||||
best_paths = k2.shortest_path(lattices, use_double_scores=True)
|
||||
|
||||
hyps = get_texts(best_paths)
|
||||
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
|
||||
|
||||
texts = supervisions["text"]
|
||||
|
||||
results = []
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
results.append((ref_words, hyp_words))
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-decode")
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_phone_id = max(lexicon.tokens)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
|
||||
HLG = HLG.to(device)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
model = TdnnLstm(
|
||||
num_features=params.feature_dim,
|
||||
num_classes=max_phone_id + 1, # +1 for the blank symbol
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames))
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
# CAUTION: `test_sets` is for displaying only.
|
||||
# If you want to skip test-clean, you have to skip
|
||||
# it inside the for loop. That is, use
|
||||
#
|
||||
# if test_set == 'test-clean': continue
|
||||
#
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
|
||||
tot_num_cuts = len(test_dl.dataset.cuts)
|
||||
num_cuts = 0
|
||||
|
||||
results = []
|
||||
for batch_idx, batch in enumerate(test_dl):
|
||||
this_batch = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
batch=batch,
|
||||
lexicon=lexicon,
|
||||
)
|
||||
results.extend(this_batch)
|
||||
|
||||
num_cuts += len(batch["supervisions"]["text"])
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
logging.info(
|
||||
f"batch {batch_idx}, cuts processed until now is "
|
||||
f"{num_cuts}/{tot_num_cuts} "
|
||||
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
|
||||
)
|
||||
|
||||
errs_filename = params.exp_dir / f"errs-{test_set}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
write_error_stats(f, test_set, results)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
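The `--epoch`/`--avg` options above select which checkpoints get averaged: the `avg` consecutive checkpoints ending at `epoch`. A small sketch of that selection logic (values are just the defaults; the path mirrors `params.exp_dir`):

```python
epoch, avg = 9, 5                 # decode.py defaults
start = epoch - avg + 1           # -> 5
filenames = [f"tdnn_lstm_ctc/exp3/epoch-{i}.pt" for i in range(start, epoch + 1)]
print(filenames)                  # epoch-5.pt ... epoch-9.pt are averaged
```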
egs/librispeech/ASR/tdnn_lstm_ctc/model.py (new file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class TdnnLstm(nn.Module):
|
||||
def __init__(
|
||||
self, num_features: int, num_classes: int, subsampling_factor: int = 3
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
num_features:
|
||||
The input dimension of the model.
|
||||
num_classes:
|
||||
The output dimension of the model.
|
||||
subsampling_factor:
|
||||
It reduces the number of output frames by this factor.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.num_classes = num_classes
|
||||
self.subsampling_factor = subsampling_factor
|
||||
self.tdnn = nn.Sequential(
|
||||
nn.Conv1d(
|
||||
in_channels=num_features,
|
||||
out_channels=500,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm1d(num_features=500, affine=False),
|
||||
nn.Conv1d(
|
||||
in_channels=500,
|
||||
out_channels=500,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm1d(num_features=500, affine=False),
|
||||
nn.Conv1d(
|
||||
in_channels=500,
|
||||
out_channels=500,
|
||||
kernel_size=3,
|
||||
stride=self.subsampling_factor, # stride: subsampling_factor!
|
||||
padding=1,
|
||||
),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.BatchNorm1d(num_features=500, affine=False),
|
||||
)
|
||||
self.lstms = nn.ModuleList(
|
||||
[
|
||||
nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
|
||||
for _ in range(5)
|
||||
]
|
||||
)
|
||||
self.lstm_bnorms = nn.ModuleList(
|
||||
[nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
|
||||
)
|
||||
self.dropout = nn.Dropout(0.2)
|
||||
self.linear = nn.Linear(in_features=500, out_features=self.num_classes)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
Its shape is [N, C, T]
|
||||
|
||||
Returns:
|
||||
The output tensor has shape [N, T, C]
|
||||
"""
|
||||
x = self.tdnn(x)
|
||||
x = x.permute(2, 0, 1) # (N, C, T) -> (T, N, C) -> how LSTM expects it
|
||||
for lstm, bnorm in zip(self.lstms, self.lstm_bnorms):
|
||||
x_new, _ = lstm(x)
|
||||
x_new = bnorm(x_new.permute(1, 2, 0)).permute(
|
||||
2, 0, 1
|
||||
) # (T, N, C) -> (N, C, T) -> (T, N, C)
|
||||
x_new = self.dropout(x_new)
|
||||
x = x_new + x # skip connections
|
||||
x = x.transpose(
|
||||
1, 0
|
||||
) # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim
|
||||
x = self.linear(x)
|
||||
x = nn.functional.log_softmax(x, dim=-1)
|
||||
return x
|
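A quick shape check for TdnnLstm, matching the docstrings above: the input is `[N, C, T]` and the output is `[N, T', num_classes]` with `T'` reduced by `subsampling_factor`. The feature and class counts below are illustrative only (the recipe derives `num_classes` from the lexicon):

```python
import torch
from model import TdnnLstm  # run from egs/librispeech/ASR/tdnn_lstm_ctc

model = TdnnLstm(num_features=80, num_classes=72, subsampling_factor=3)
x = torch.randn(2, 80, 300)   # [N, C, T]
y = model(x)
print(y.shape)                # torch.Size([2, 100, 72]); 300 frames -> 100 after subsampling
```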
egs/librispeech/ASR/tdnn_lstm_ctc/train.py (new executable file, 493 lines)
@@ -0,0 +1,493 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# This is just at the very beginning ...
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Optional
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from model import TdnnLstm
|
||||
from torch.nn.utils import clip_grad_value_
|
||||
from torch.optim.lr_scheduler import StepLR
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
|
||||
from icefall.graph_compiler import CtcTrainingGraphCompiler
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import AttributeDict, encode_supervisions, setup_logger
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12354,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
# TODO: add extra arguments and support DDP training.
|
||||
# Currently, only single GPU training is implemented. Will add
|
||||
# DDP training once single GPU training is finished.
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
"""Return a dict containing training parameters.
|
||||
|
||||
All training-related parameters that are not passed from the command line
are saved in the variable `params`.
|
||||
|
||||
Commandline options are merged into `params` after they are parsed, so
|
||||
you can also access them via `params`.
|
||||
|
||||
Explanation of options saved in `params`:
|
||||
|
||||
- exp_dir: It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
||||
- lang_dir: It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
|
||||
- lr: It specifies the initial learning rate
|
||||
|
||||
- feature_dim: The model input dim. It has to match the one used
|
||||
in computing features.
|
||||
|
||||
- weight_decay: The weight_decay for the optimizer.
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- start_epoch: If it is not zero, load checkpoint `start_epoch-1`
|
||||
and continue training from that checkpoint.
|
||||
|
||||
- num_epochs: Number of epochs to train.
|
||||
|
||||
- best_train_loss: Best training loss so far. It is used to select
|
||||
the model that has the lowest training loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_valid_loss: Best validation loss so far. It is used to select
|
||||
the model that has the lowest validation loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_train_epoch: It is the epoch that has the best training loss.
|
||||
|
||||
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||
|
||||
- batch_idx_train: Used for writing statistics to tensorboard. It
|
||||
contains number of batches trained so far across
|
||||
epochs.
|
||||
|
||||
- log_interval: Print training loss if `batch_idx % log_interval` is 0
|
||||
|
||||
- valid_interval: Run validation if `batch_idx % valid_interval` is 0
|
||||
|
||||
- beam_size: It is used in k2.ctc_loss
|
||||
|
||||
- reduction: It is used in k2.ctc_loss
|
||||
|
||||
- use_double_scores: It is used in k2.ctc_loss
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"exp_dir": Path("tdnn_lstm_ctc/exp"),
|
||||
"lang_dir": Path("data/lang"),
|
||||
"lr": 1e-3,
|
||||
"feature_dim": 80,
|
||||
"weight_decay": 5e-4,
|
||||
"subsampling_factor": 3,
|
||||
"start_epoch": 0,
|
||||
"num_epochs": 10,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 10,
|
||||
"valid_interval": 1000,
|
||||
"beam_size": 10,
|
||||
"reduction": "sum",
|
||||
"use_double_scores": True,
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: torch.optim.lr_scheduler._LRScheduler = None,
|
||||
) -> None:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is positive, it will load the checkpoint from
|
||||
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||
|
||||
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
The learning rate scheduler we are using.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
if params.start_epoch <= 0:
|
||||
return
|
||||
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: torch.optim.lr_scheduler._LRScheduler,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
  The training model.
optimizer:
  The optimizer used in training. Its state is saved in the checkpoint.
scheduler:
  The learning rate scheduler. Its state is saved in the checkpoint.
"""
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
batch: dict,
|
||||
graph_compiler: CtcTrainingGraphCompiler,
|
||||
is_training: bool,
|
||||
):
|
||||
"""
|
||||
Compute CTC loss given the model and its inputs.
|
||||
|
||||
Args:
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
model:
|
||||
The model for training. It is an instance of TdnnLstm in our case.
|
||||
batch:
|
||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||
for the content in it.
|
||||
graph_compiler:
|
||||
It is used to build a decoding graph from a ctc topo and training
|
||||
transcript. The training transcript is contained in the given `batch`,
|
||||
while the ctc topo is built when this compiler is instantiated.
|
||||
is_training:
|
||||
True for training. False for validation. When it is True, this
|
||||
function enables autograd during computation; when it is False, it
|
||||
disables autograd.
|
||||
"""
|
||||
device = graph_compiler.device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is [N, T, C]
|
||||
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
nnet_output = model(feature)
|
||||
# nnet_output is [N, T, C]
|
||||
|
||||
# NOTE: We need `encode_supervisions` to sort sequences with
|
||||
# different duration in decreasing order, required by
|
||||
# `k2.intersect_dense` called in `k2.ctc_loss`
|
||||
supervisions = batch["supervisions"]
|
||||
supervision_segments, texts = encode_supervisions(
|
||||
supervisions, subsampling_factor=params.subsampling_factor
|
||||
)
|
||||
decoding_graph = graph_compiler.compile(texts)
|
||||
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
nnet_output,
|
||||
supervision_segments,
|
||||
allow_truncate=params.subsampling_factor - 1,
|
||||
)
|
||||
|
||||
loss = k2.ctc_loss(
|
||||
decoding_graph=decoding_graph,
|
||||
dense_fsa_vec=dense_fsa_vec,
|
||||
output_beam=params.beam_size,
|
||||
reduction=params.reduction,
|
||||
use_double_scores=params.use_double_scores,
|
||||
)
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
# train_frames and valid_frames are used for printing.
|
||||
if is_training:
|
||||
params.train_frames = supervision_segments[:, 2].sum().item()
|
||||
else:
|
||||
params.valid_frames = supervision_segments[:, 2].sum().item()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
graph_compiler: CtcTrainingGraphCompiler,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
) -> None:
|
||||
"""Run the validation process. The validation loss
|
||||
is saved in `params.valid_loss`.
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
tot_loss = 0.0
|
||||
tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=False,
|
||||
)
|
||||
assert loss.requires_grad is False
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
tot_loss += loss_cpu
|
||||
tot_frames += params.valid_frames
|
||||
|
||||
params.valid_loss = tot_loss / tot_frames
|
||||
|
||||
if params.valid_loss < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = params.valid_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
graph_compiler: CtcTrainingGraphCompiler,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss from the mean of all frames is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer we are using.
|
||||
graph_compiler:
|
||||
It is used to convert transcripts to FSAs.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = 0.0 # sum of losses over all batches
|
||||
tot_frames = 0.0 # sum of frames over all batches
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
loss = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=True,
|
||||
)
|
||||
|
||||
# NOTE: We use reduction="sum", so the loss is summed over all utterances
# in the batch and no normalization is applied to it so far.
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
clip_grad_value_(model.parameters(), 5.0)
|
||||
optimizer.step()
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
|
||||
tot_frames += params.train_frames
|
||||
tot_loss += loss_cpu
|
||||
tot_avg_loss = tot_loss / tot_frames
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
|
||||
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||
f"batch size: {batch_size}"
|
||||
)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
graph_compiler=graph_compiler,
|
||||
valid_dl=valid_dl,
|
||||
)
|
||||
model.train()
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss}, "
|
||||
f"best valid loss: {params.best_valid_loss:.4f} "
|
||||
f"best valid epoch: {params.best_valid_epoch}"
|
||||
)
|
||||
|
||||
params.train_loss = tot_loss / tot_frames
|
||||
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
logging.info(params)
|
||||
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_phone_id = max(lexicon.tokens)
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)
|
||||
|
||||
model = TdnnLstm(
|
||||
num_features=params.feature_dim,
|
||||
num_classes=max_phone_id + 1, # +1 for the blank symbol
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
optimizer = optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=params.lr,
|
||||
weight_decay=params.weight_decay,
|
||||
)
|
||||
scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
|
||||
|
||||
load_checkpoint_if_available(
|
||||
params=params, model=model, optimizer=optimizer
|
||||
)
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
train_dl = librispeech.train_dataloaders()
|
||||
valid_dl = librispeech.valid_dataloaders()
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
|
||||
if epoch > params.start_epoch:
|
||||
logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}")
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/lr",
|
||||
scheduler.get_last_lr()[0],
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
graph_compiler=graph_compiler,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
tb_writer=tb_writer,
|
||||
)
|
||||
|
||||
scheduler.step()
|
||||
|
||||
save_checkpoint(
|
||||
params=params, model=model, optimizer=optimizer, scheduler=scheduler
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
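An illustrative invocation; `--world-size`/`--master-port` are defined above, and the data-related flags come from `LibriSpeechAsrDataModule`/`AsrDataModule` (see `icefall/dataset/` below). The values are examples only:

```bash
./tdnn_lstm_ctc/train.py \
  --full-libri false \
  --bucketing-sampler true \
  --max-duration 300
```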
icefall/checkpoint.py (new file, 131 lines)
@@ -0,0 +1,131 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
filename: Path,
|
||||
model: Union[nn.Module, DDP],
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[_LRScheduler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save training information to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
The checkpoint filename.
|
||||
model:
|
||||
The model to be saved. We only save its `state_dict()`.
|
||||
params:
|
||||
User defined parameters, e.g., epoch, loss.
|
||||
optimizer:
|
||||
The optimizer to be saved. We only save its `state_dict()`.
|
||||
scheduler:
|
||||
The scheduler to be saved. We only save its `state_dict()`.
|
||||
scaler:
|
||||
The GradScaler to be saved. We only save its `state_dict()`.
|
||||
rank:
|
||||
Used in DDP. We save checkpoint only for the node whose rank is 0.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
|
||||
logging.info(f"Saving checkpoint to {filename}")
|
||||
|
||||
if isinstance(model, DDP):
|
||||
model = model.module
|
||||
|
||||
checkpoint = {
|
||||
"model": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict() if optimizer is not None else None,
|
||||
"scheduler": scheduler.state_dict() if scheduler is not None else None,
|
||||
"grad_scaler": scaler.state_dict() if scaler is not None else None,
|
||||
}
|
||||
|
||||
if params:
|
||||
for k, v in params.items():
|
||||
assert k not in checkpoint
|
||||
checkpoint[k] = v
|
||||
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
filename: Path,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[_LRScheduler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load a checkpoint saved by :func:`save_checkpoint` from `filename`.

It loads the state dict of `model` (handling checkpoints saved by DDP) and,
if given, of `optimizer`, `scheduler` and `scaler`. The remaining entries of
the checkpoint (e.g., user-defined params) are returned as a dict.
|
||||
"""
|
||||
logging.info(f"Loading checkpoint from {filename}")
|
||||
checkpoint = torch.load(filename, map_location="cpu")
|
||||
|
||||
if next(iter(checkpoint["model"])).startswith("module."):
|
||||
logging.info("Loading checkpoint saved by DDP")
|
||||
|
||||
dst_state_dict = model.state_dict()
|
||||
src_state_dict = checkpoint["model"]
|
||||
for key in dst_state_dict.keys():
|
||||
src_key = "{}.{}".format("module", key)
|
||||
dst_state_dict[key] = src_state_dict.pop(src_key)
|
||||
assert len(src_state_dict) == 0
|
||||
model.load_state_dict(dst_state_dict, strict=False)
|
||||
else:
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
|
||||
checkpoint.pop("model")
|
||||
|
||||
def load(name, obj):
|
||||
s = checkpoint[name]
|
||||
if obj and s:
|
||||
obj.load_state_dict(s)
|
||||
checkpoint.pop(name)
|
||||
|
||||
load("optimizer", optimizer)
|
||||
load("scheduler", scheduler)
|
||||
load("grad_scaler", scaler)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def average_checkpoints(filenames: List[Path]) -> dict:
|
||||
"""Average a list of checkpoints.
|
||||
|
||||
Args:
|
||||
filenames:
|
||||
Filenames of the checkpoints to be averaged. We assume all
|
||||
checkpoints are saved by :func:`save_checkpoint`.
|
||||
Returns:
|
||||
Return a dict (i.e., state_dict) which is the average of all
|
||||
model state dicts contained in the checkpoints.
|
||||
"""
|
||||
n = len(filenames)
|
||||
|
||||
avg = torch.load(filenames[0], map_location="cpu")["model"]
|
||||
for i in range(1, n):
|
||||
state_dict = torch.load(filenames[i], map_location="cpu")["model"]
|
||||
for k in avg:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
for k in avg:
|
||||
if avg[k].is_floating_point():
|
||||
avg[k] /= n
|
||||
else:
|
||||
avg[k] //= n
|
||||
|
||||
return avg
|
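A minimal round-trip sketch of how these helpers fit together, with hypothetical filenames; decode.py uses the same `average_checkpoints(...)` + `load_state_dict(...)` pattern on real epoch checkpoints:

```python
import torch.nn as nn

from icefall.checkpoint import average_checkpoints, save_checkpoint

# Save two tiny dummy checkpoints (hypothetical paths), then average them.
model = nn.Linear(4, 2)
save_checkpoint("epoch-0.pt", model=model)
save_checkpoint("epoch-1.pt", model=model)

avg = average_checkpoints(["epoch-0.pt", "epoch-1.pt"])
model.load_state_dict(avg)
```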
icefall/dataset/__init__.py (new, empty file)

icefall/dataset/asr_datamodule.py (new file, 248 lines)
@@ -0,0 +1,248 @@
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from lhotse import Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
K2SpeechRecognitionDataset,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.dataset.datamodule import DataModule
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class AsrDataModule(DataModule):
|
||||
"""
|
||||
DataModule for K2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
super().add_arguments(parser)
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--feature-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=500.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the BucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
|
||||
def train_dataloaders(self) -> DataLoader:
|
||||
logging.info("About to get train cuts")
|
||||
cuts_train = self.train_cuts()
|
||||
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = [
|
||||
SpecAugment(
|
||||
num_frame_masks=2,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
]
|
||||
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cuts_train,
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
cuts_train = cuts_train.drop_features()
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cuts=cuts_train,
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_transforms=input_transforms,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using BucketingSampler.")
|
||||
train_sampler = BucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=True,
|
||||
num_buckets=self.args.num_buckets,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=True,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
)
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self) -> DataLoader:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = self.valid_cuts()
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
cuts_valid = cuts_valid.drop_features()
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cuts_valid,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(cuts_valid)
|
||||
valid_sampler = SingleCutSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=True,
|
||||
)
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
|
||||
cuts = self.test_cuts()
|
||||
is_list = isinstance(cuts, list)
|
||||
test_loaders = []
|
||||
if not is_list:
|
||||
cuts = [cuts]
|
||||
|
||||
for cuts_test in cuts:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
cuts_test,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
)
|
||||
sampler = SingleCutSampler(
|
||||
cuts_test, max_duration=self.args.max_duration
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test, batch_size=None, sampler=sampler, num_workers=1
|
||||
)
|
||||
test_loaders.append(test_dl)
|
||||
|
||||
if is_list:
|
||||
return test_loaders
|
||||
else:
|
||||
return test_loaders[0]
|
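For orientation, a fragment (not self-contained; it assumes `train_dl` was created by `train_dataloaders()` above) showing the batch layout that train.py and decode.py rely on:

```python
for batch in train_dl:
    feature = batch["inputs"]            # [N, T, C] fbank features
    supervisions = batch["supervisions"]
    texts = supervisions["text"]         # list of N transcripts
    # Frame-level segment info consumed by encode_supervisions() in icefall.utils:
    seq_idx = supervisions["sequence_idx"]
    start_frame = supervisions["start_frame"]
    num_frames = supervisions["num_frames"]
    break
```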
icefall/dataset/datamodule.py (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
import argparse
|
||||
from typing import List, Union
|
||||
|
||||
from lhotse import CutSet
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class DataModule:
|
||||
"""
|
||||
Contains dataset-related code. It is intended to read/construct Lhotse cuts,
|
||||
and create Dataset/Sampler/DataLoader out of them.
|
||||
|
||||
There is a separate method to create each of train/valid/test DataLoader.
|
||||
In principle, there might be multiple DataLoaders for each of
|
||||
train/valid/test
|
||||
(e.g. when a corpus has multiple test sets).
|
||||
The API of this class allows returning lists of CutSets/DataLoaders.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
pass
|
||||
|
||||
def train_cuts(self) -> Union[CutSet, List[CutSet]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def valid_cuts(self) -> Union[CutSet, List[CutSet]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def test_cuts(self) -> Union[CutSet, List[CutSet]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def train_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def valid_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
|
||||
raise NotImplementedError()
|
icefall/dataset/librispeech.py (new file, 68 lines)
@@ -0,0 +1,68 @@
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
|
||||
from lhotse import CutSet, load_manifest
|
||||
|
||||
from icefall.dataset.asr_datamodule import AsrDataModule
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule(AsrDataModule):
|
||||
"""
|
||||
LibriSpeech ASR data module. Can be used for 100h subset
|
||||
(``--full-libri false``) or full 960h set.
|
||||
The train and valid cuts for standard Libri splits are
|
||||
concatenated into a single CutSet/DataLoader.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
super().add_arguments(parser)
|
||||
group = parser.add_argument_group(title="LibriSpeech specific options")
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use 960h LibriSpeech.",
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
cuts_train = load_manifest(
|
||||
self.args.feature_dir / "cuts_train-clean-100.json.gz"
|
||||
)
|
||||
if self.args.full_libri:
|
||||
cuts_train = (
|
||||
cuts_train
|
||||
+ load_manifest(
|
||||
self.args.feature_dir / "cuts_train-clean-360.json.gz"
|
||||
)
|
||||
+ load_manifest(
|
||||
self.args.feature_dir / "cuts_train-other-500.json.gz"
|
||||
)
|
||||
)
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = load_manifest(
|
||||
self.args.feature_dir / "cuts_dev-clean.json.gz"
|
||||
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
|
||||
return cuts_valid
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> List[CutSet]:
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
cuts = []
|
||||
for test_set in test_sets:
|
||||
logging.debug("About to get test cuts")
|
||||
cuts.append(
|
||||
load_manifest(
|
||||
self.args.feature_dir / f"cuts_{test_set}.json.gz"
|
||||
)
|
||||
)
|
||||
return cuts
|
icefall/graph_compiler.py (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import Lexicon
|
||||
|
||||
|
||||
class CtcTrainingGraphCompiler(object):
|
||||
def __init__(
|
||||
self,
|
||||
lexicon: Lexicon,
|
||||
device: torch.device,
|
||||
oov: str = "<UNK>",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
lexicon:
|
||||
It is built from `data/lang/lexicon.txt`.
|
||||
device:
|
||||
The device to use for operations compiling transcripts to FSAs.
|
||||
oov:
|
||||
Out of vocabulary word. When a word in the transcript
|
||||
does not exist in the lexicon, it is replaced with `oov`.
|
||||
"""
|
||||
L_inv = lexicon.L_inv.to(device)
|
||||
assert L_inv.requires_grad is False
|
||||
|
||||
assert oov in lexicon.words
|
||||
|
||||
self.L_inv = k2.arc_sort(L_inv)
|
||||
self.oov_id = lexicon.words[oov]
|
||||
self.words = lexicon.words
|
||||
|
||||
max_token_id = max(lexicon.tokens)
|
||||
ctc_topo = k2.ctc_topo(max_token_id, modified=False)
|
||||
|
||||
self.ctc_topo = ctc_topo.to(device)
|
||||
self.device = device
|
||||
|
||||
def compile(self, texts: List[str]) -> k2.Fsa:
|
||||
"""Build decoding graphs by composing ctc_topo with
|
||||
given transcripts.
|
||||
|
||||
Args:
|
||||
texts:
|
||||
A list of strings. Each string contains a sentence for an utterance.
|
||||
A sentence consists of space-separated words. An example `texts`
|
||||
looks like:
|
||||
|
||||
['hello icefall', 'CTC training with k2']
|
||||
|
||||
Returns:
|
||||
An FsaVec, the composition result of `self.ctc_topo` and the
|
||||
transcript FSA.
|
||||
"""
|
||||
transcript_fsa = self.convert_transcript_to_fsa(texts)
|
||||
|
||||
# NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
|
||||
# is False, so we add epsilon self-loops here
|
||||
fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
|
||||
transcript_fsa
|
||||
)
|
||||
|
||||
fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)
|
||||
|
||||
decoding_graph = k2.compose(
|
||||
self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
|
||||
)
|
||||
|
||||
assert decoding_graph.requires_grad is False
|
||||
|
||||
return decoding_graph
|
||||
|
||||
def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
|
||||
"""Convert a list of transcript texts to an FsaVec.
|
||||
|
||||
Args:
|
||||
texts:
|
||||
A list of strings. Each string contains a sentence for an utterance.
|
||||
A sentence consists of space-separated words. An example `texts`
|
||||
looks like:
|
||||
|
||||
['hello icefall', 'CTC training with k2']
|
||||
|
||||
Returns:
|
||||
Return an FsaVec, whose `shape[0]` equals to `len(texts)`.
|
||||
"""
|
||||
word_ids_list = []
|
||||
for text in texts:
|
||||
word_ids = []
|
||||
for word in text.split(" "):
|
||||
if word in self.words:
|
||||
word_ids.append(self.words[word])
|
||||
else:
|
||||
word_ids.append(self.oov_id)
|
||||
word_ids_list.append(word_ids)
|
||||
|
||||
word_fsa = k2.linear_fsa(word_ids_list, self.device)
|
||||
|
||||
word_fsa_with_self_loops = k2.add_epsilon_self_loops(word_fsa)
|
||||
|
||||
fsa = k2.intersect(
|
||||
self.L_inv, word_fsa_with_self_loops, treat_epsilons_specially=False
|
||||
)
|
||||
# fsa has word ID as labels and token ID as aux_labels, so
|
||||
# we need to invert it
|
||||
ans_fsa = fsa.invert_()
|
||||
return k2.arc_sort(ans_fsa)
|
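Usage mirrors train.py; a sketch assuming `data/lang` has been created by prepare.sh (the transcripts are illustrative):

```python
import torch

from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon

lexicon = Lexicon("data/lang")
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)

texts = ["HELLO WORLD", "CTC TRAINING WITH K2"]  # illustrative transcripts
decoding_graph = graph_compiler.compile(texts)   # FsaVec with len(texts) FSAs
```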
icefall/lexicon.py (new file, 66 lines)
@@ -0,0 +1,66 @@
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
|
||||
class Lexicon(object):
|
||||
"""Phone based lexicon.
|
||||
|
||||
TODO: Add BpeLexicon for BPE models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$")
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
Path to the lang directory. It is expected to contain the following
|
||||
files:
|
||||
- phones.txt
|
||||
- words.txt
|
||||
- L.pt
|
||||
The above files are produced by the script `prepare.sh`. You
|
||||
should have run that before running the training code.
|
||||
disambig_pattern:
|
||||
It contains the pattern for disambiguation symbols.
|
||||
"""
|
||||
lang_dir = Path(lang_dir)
|
||||
self.phones = k2.SymbolTable.from_file(lang_dir / "phones.txt")
|
||||
self.words = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
if (lang_dir / "Linv.pt").exists():
|
||||
logging.info("Loading pre-compiled Linv.pt")
|
||||
L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt"))
|
||||
else:
|
||||
logging.info("Converting L.pt to Linv.pt")
|
||||
L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt"))
|
||||
L_inv = k2.arc_sort(L.invert())
|
||||
torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
|
||||
|
||||
# We save L_inv instead of L because it will be used to intersect with
|
||||
# transcript, both of whose labels are word IDs.
|
||||
self.L_inv = L_inv
|
||||
self.disambig_pattern = disambig_pattern
|
||||
|
||||
@property
|
||||
def tokens(self) -> List[int]:
|
||||
"""Return a list of phone IDs excluding those from
|
||||
disambiguation symbols.
|
||||
|
||||
Caution:
|
||||
0 is not a phone ID so it is excluded from the return value.
|
||||
"""
|
||||
symbols = self.phones.symbols
|
||||
ans = []
|
||||
for s in symbols:
|
||||
if not self.disambig_pattern.match(s):
|
||||
ans.append(self.phones[s])
|
||||
if 0 in ans:
|
||||
ans.remove(0)
|
||||
ans.sort()
|
||||
return ans
|
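A sketch of how the other scripts in this commit use `Lexicon` (again assuming `data/lang` exists):

```python
from icefall.lexicon import Lexicon

lexicon = Lexicon("data/lang")

max_token_id = max(lexicon.tokens)              # as in compile_hlg.py / train.py
first_token_disambig_id = lexicon.phones["#0"]  # disambig symbol lookups used
first_word_disambig_id = lexicon.words["#0"]    # when building HLG
print(max_token_id, first_token_disambig_id, first_word_disambig_id)
```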
icefall/utils.py (298 lines changed)
@@ -1,5 +1,20 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, TextIO, Tuple, Union
|
||||
|
||||
import k2
|
||||
import k2.ragged as k2r
|
||||
import kaldialign
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
Pathlike = Union[str, Path]
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -32,3 +47,286 @@ def get_executor():
|
||||
# No need to return anything - compute_and_store_features
|
||||
# will just instantiate the pool itself.
|
||||
yield None
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
"""Used in argparse.ArgumentParser.add_argument to indicate
|
||||
that a type is a bool type and the user can enter
|
||||
|
||||
- yes, true, t, y, 1, to represent True
|
||||
- no, false, f, n, 0, to represent False
|
||||
|
||||
See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa
|
||||
"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
elif v.lower() in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
|
||||
def setup_logger(
|
||||
log_filename: Pathlike, log_level: str = "info", use_console: bool = True
|
||||
) -> None:
|
||||
"""Setup log level.
|
||||
|
||||
Args:
|
||||
log_filename:
|
||||
The filename to save the log.
|
||||
log_level:
|
||||
The log level to use, e.g., "debug", "info", "warning", "error",
|
||||
"critical"
|
||||
"""
|
||||
now = datetime.now()
|
||||
date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
world_size = dist.get_world_size()
|
||||
rank = dist.get_rank()
|
||||
formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa
|
||||
log_filename = f"{log_filename}-{date_time}-{rank}"
|
||||
else:
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
log_filename = f"{log_filename}-{date_time}"
|
||||
|
||||
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
|
||||
|
||||
level = logging.ERROR
|
||||
if log_level == "debug":
|
||||
level = logging.DEBUG
|
||||
elif log_level == "info":
|
||||
level = logging.INFO
|
||||
elif log_level == "warning":
|
||||
level = logging.WARNING
|
||||
elif log_level == "critical":
|
||||
level = logging.CRITICAL
|
||||
|
||||
logging.basicConfig(
|
||||
filename=log_filename, format=formatter, level=level, filemode="w"
|
||||
)
|
||||
if use_console:
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(level)
|
||||
console.setFormatter(logging.Formatter(formatter))
|
||||
logging.getLogger("").addHandler(console)
|
||||
|
||||
|
||||
def get_env_info():
|
||||
"""
|
||||
TODO:
|
||||
"""
|
||||
return {
|
||||
"k2-git-sha1": None,
|
||||
"k2-version": None,
|
||||
"lhotse-version": None,
|
||||
"torch-version": None,
|
||||
"icefall-sha1": None,
|
||||
"icefall-version": None,
|
||||
}
|
||||
|
||||
|
||||
# See
|
||||
# https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute # noqa
|
||||
class AttributeDict(dict):
|
||||
__slots__ = ()
|
||||
__getattr__ = dict.__getitem__
|
||||
__setattr__ = dict.__setitem__
|
||||
|
||||
|
||||
def encode_supervisions(
|
||||
supervisions: Dict[str, torch.Tensor], subsampling_factor: int
|
||||
) -> Tuple[torch.Tensor, List[str]]:
|
||||
"""
|
||||
Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of a torch Tensor
and a list of transcription strings.
|
||||
|
||||
The supervision tensor has shape ``(batch_size, 3)``.
|
||||
Its second dimension contains information about sequence index [0],
|
||||
start frames [1] and num frames [2].
|
||||
|
||||
The batch items might become re-ordered during this operation -- the
|
||||
returned tensor and list of strings are guaranteed to be consistent with
|
||||
each other.
|
||||
"""
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"] // subsampling_factor,
|
||||
supervisions["num_frames"] // subsampling_factor,
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
indices = torch.argsort(supervision_segments[:, 2], descending=True)
|
||||
supervision_segments = supervision_segments[indices]
|
||||
texts = supervisions["text"]
|
||||
texts = [texts[idx] for idx in indices]
|
||||
|
||||
return supervision_segments, texts
|
||||
|
||||
|
||||
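A small usage sketch, not part of this commit (it mirrors the new test/test_utils.py added below); the numbers are arbitrary.

sup = {
    "sequence_idx": torch.tensor([0, 1]),
    "start_frame": torch.tensor([1, 3]),
    "num_frames": torch.tensor([20, 30]),
    "text": ["one", "two"],
}
segments, texts = encode_supervisions(sup, subsampling_factor=4)
# Sorted by descending duration:
# segments is now tensor([[1, 0, 7], [0, 0, 5]]) and texts == ["two", "one"]
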
def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
    """Extract the texts from the best-path FSAs.
    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
    Returns:
      Returns a list of lists of int, containing the label sequences we
      decoded.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedInt):
        # remove 0's and -1's.
        aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0)
        aux_shape = k2r.compose_ragged_shapes(
            best_paths.arcs.shape(), aux_labels.shape()
        )

        # remove the states and arcs axes.
        aux_shape = k2r.remove_axis(aux_shape, 1)
        aux_shape = k2r.remove_axis(aux_shape, 1)
        aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())
    else:
        # remove axis corresponding to states.
        aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
        aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = k2r.remove_values_leq(aux_labels, 0)

    assert aux_labels.num_axes() == 2
    return k2r.to_list(aux_labels)

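A small usage sketch, not part of this commit (it mirrors the new test/test_utils.py added below): for an FsaVec with regular (tensor) aux_labels, get_texts drops the 0 and -1 entries and returns the remaining aux labels per FSA.

fsa = k2.Fsa.from_str(
    """
    0 1 1 3 10
    1 2 2 0 20
    2 3 -1 -1 0
    3
    """,
    num_aux_labels=1,
)
assert get_texts(k2.Fsa.from_fsas([fsa])) == [[3]]
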
def write_error_stats(
    f: TextIO, test_set_name: str, results: List[Tuple[str, str]]
) -> float:
    """Write statistics based on the decoding results to the file ``f``.

    ``results`` is a list of ``(ref, hyp)`` pairs for ``test_set_name``.
    The statistics include per-utterance alignments, substitutions,
    deletions, insertions and per-word counts. Returns the total error
    rate (WER) in percent as a float.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    # corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    logging.info(
        f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
        f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
        f"{del_errs} del, {sub_errs} sub ]"
    )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            " ".join(
                (
                    ref_word
                    if ref_word == hyp_word
                    else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted(
        [(v, k) for k, v in subs.items()], reverse=True
    ):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print(
        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
    )
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)

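A minimal usage sketch, not part of this commit; the file name and the toy results are made up. Each entry of results is a (reference, hypothesis) pair, here given as word lists since kaldialign.align aligns the entries element-wise.

results = [(["the", "cat", "sat"], ["the", "cat", "sad"])]
with open("errs-test-clean.txt", "w") as f:
    wer = write_error_stats(f, test_set_name="test-clean", results=results)
# wer == 33.33: one substitution out of three reference words
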
50
test/test_checkpoint.py
Normal file
@ -0,0 +1,50 @@
#!/usr/bin/env python3

import pytest
import torch
import torch.nn as nn

from icefall.checkpoint import (
    average_checkpoints,
    load_checkpoint,
    save_checkpoint,
)


@pytest.fixture
def checkpoints1(tmp_path):
    f = tmp_path / "f.pt"
    m = nn.Module()
    m.p1 = nn.Parameter(torch.tensor([10.0, 20.0]), requires_grad=False)
    m.register_buffer("p2", torch.tensor([10, 100]))

    params = {"a": 10, "b": 20}
    save_checkpoint(f, m, params=params)
    return f


@pytest.fixture
def checkpoints2(tmp_path):
    f = tmp_path / "f2.pt"
    m = nn.Module()
    m.p1 = nn.Parameter(torch.Tensor([50, 30.0]))
    m.register_buffer("p2", torch.tensor([1, 3]))
    params = {"a": 100, "b": 200}

    save_checkpoint(f, m, params=params)
    return f


def test_load_checkpoints(checkpoints1):
    m = nn.Module()
    m.p1 = nn.Parameter(torch.Tensor([0, 0.0]))
    m.p2 = nn.Parameter(torch.Tensor([0, 0]))
    params = load_checkpoint(checkpoints1, m)
    assert torch.allclose(m.p1, torch.Tensor([10.0, 20]))
    assert params == {"a": 10, "b": 20}


def test_average_checkpoints(checkpoints1, checkpoints2):
    state_dict = average_checkpoints([checkpoints1, checkpoints2])
    assert torch.allclose(state_dict["p1"], torch.Tensor([30, 25.0]))
    assert torch.allclose(state_dict["p2"], torch.tensor([5, 51]))
160
test/test_graph_compiler.py
Normal file
@ -0,0 +1,160 @@
#!/usr/bin/env python3

# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

import re

import k2
import pytest
import torch

from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import get_texts


@pytest.fixture
def lexicon():
    """
    We use the following test data:

    lexicon.txt

        foo f o o
        bar b a r
        baz b a z
        <UNK> SPN

    phones.txt

        <eps> 0
        a 1
        b 2
        f 3
        o 4
        r 5
        z 6
        SPN 7

    words.txt:

        <eps> 0
        foo 1
        bar 2
        baz 3
        <UNK> 4
    """
    L = k2.Fsa.from_str(
        """
        0 0 7 4 0
        0 7 -1 -1 0
        0 1 3 1 0
        0 3 2 2 0
        0 5 2 3 0
        1 2 4 0 0
        2 0 4 0 0
        3 4 1 0 0
        4 0 5 0 0
        5 6 1 0 0
        6 0 6 0 0
        7
        """,
        num_aux_labels=1,
    )
    L.labels_sym = k2.SymbolTable.from_str(
        """
        a 1
        b 2
        f 3
        o 4
        r 5
        z 6
        SPN 7
        """
    )
    L.aux_labels_sym = k2.SymbolTable.from_str(
        """
        foo 1
        bar 2
        baz 3
        <UNK> 4
        """
    )
    ans = Lexicon.__new__(Lexicon)
    ans.phones = L.labels_sym
    ans.words = L.aux_labels_sym
    ans.L_inv = k2.arc_sort(L.invert_())
    ans.disambig_pattern = re.compile(r"^#\d+$")

    return ans


@pytest.fixture
def compiler(lexicon):
    return CtcTrainingGraphCompiler(lexicon, device=torch.device("cpu"))


class TestCtcTrainingGraphCompiler(object):
    @staticmethod
    def test_convert_transcript_to_fsa(compiler, lexicon):
        texts = ["bar foo", "baz ok"]
        fsa = compiler.convert_transcript_to_fsa(texts)
        labels0 = fsa[0].labels[:-1].tolist()
        aux_labels0 = fsa[0].aux_labels[:-1]
        aux_labels0 = aux_labels0[aux_labels0 != 0].tolist()

        labels1 = fsa[1].labels[:-1].tolist()
        aux_labels1 = fsa[1].aux_labels[:-1]
        aux_labels1 = aux_labels1[aux_labels1 != 0].tolist()

        labels0 = [lexicon.phones[i] for i in labels0]
        labels1 = [lexicon.phones[i] for i in labels1]

        aux_labels0 = [lexicon.words[i] for i in aux_labels0]
        aux_labels1 = [lexicon.words[i] for i in aux_labels1]

        assert labels0 == ["b", "a", "r", "f", "o", "o"]
        assert aux_labels0 == ["bar", "foo"]

        assert labels1 == ["b", "a", "z", "SPN"]
        assert aux_labels1 == ["baz", "<UNK>"]

    @staticmethod
    def test_compile(compiler, lexicon):
        texts = ["bar foo", "baz ok"]
        decoding_graph = compiler.compile(texts)
        input1 = ["b", "b", "<blk>", "<blk>", "a", "a", "r", "<blk>", "<blk>"]
        input1 += ["f", "f", "<blk>", "<blk>", "o", "o", "<blk>", "o", "o"]

        input2 = ["b", "b", "a", "a", "a", "<blk>", "<blk>", "z", "z"]
        input2 += ["<blk>", "<blk>", "SPN", "SPN", "<blk>", "<blk>"]

        # Map id 0 to the blank symbol for this test
        lexicon.phones._id2sym[0] = "<blk>"
        lexicon.phones._sym2id["<blk>"] = 0

        input1 = [lexicon.phones[i] for i in input1]
        input2 = [lexicon.phones[i] for i in input2]

        fsa1 = k2.linear_fsa(input1)
        fsa2 = k2.linear_fsa(input2)
        fsas = k2.Fsa.from_fsas([fsa1, fsa2])

        decoding_graph = k2.arc_sort(decoding_graph)
        lattice = k2.intersect(
            decoding_graph, fsas, treat_epsilons_specially=False
        )
        lattice = k2.connect(lattice)

        aux_labels0 = lattice[0].aux_labels[:-1]
        aux_labels0 = aux_labels0[aux_labels0 != 0].tolist()
        aux_labels0 = [lexicon.words[i] for i in aux_labels0]
        assert aux_labels0 == ["bar", "foo"]

        aux_labels1 = lattice[1].aux_labels[:-1]
        aux_labels1 = aux_labels1[aux_labels1 != 0].tolist()
        aux_labels1 = [lexicon.words[i] for i in aux_labels1]
        assert aux_labels1 == ["baz", "<UNK>"]

        texts = get_texts(lattice)
        texts = [[lexicon.words[i] for i in words] for words in texts]
        assert texts == [["bar", "foo"], ["baz", "<UNK>"]]
93
test/test_utils.py
Normal file
@ -0,0 +1,93 @@
#!/usr/bin/env python3
import k2
import pytest
import torch

from icefall.utils import AttributeDict, encode_supervisions, get_texts


@pytest.fixture
def sup():
    sequence_idx = torch.tensor([0, 1, 2])
    start_frame = torch.tensor([1, 3, 9])
    num_frames = torch.tensor([20, 30, 10])
    text = ["one", "two", "three"]
    return {
        "sequence_idx": sequence_idx,
        "start_frame": start_frame,
        "num_frames": num_frames,
        "text": text,
    }


def test_encode_supervisions(sup):
    supervision_segments, texts = encode_supervisions(sup, subsampling_factor=4)
    assert torch.all(
        torch.eq(
            supervision_segments,
            torch.tensor(
                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
            ),
        )
    )
    assert texts == ["two", "one", "three"]


def test_get_texts_ragged():
    fsa1 = k2.Fsa.from_str(
        """
        0 1 1 10
        1 2 2 20
        2 3 3 30
        3 4 -1 0
        4
        """
    )
    fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]")

    fsa2 = k2.Fsa.from_str(
        """
        0 1 1 1
        1 2 2 2
        2 3 -1 0
        3
        """
    )
    fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]")
    fsas = k2.Fsa.from_fsas([fsa1, fsa2])
    texts = get_texts(fsas)
    assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]


def test_get_texts_regular():
    fsa1 = k2.Fsa.from_str(
        """
        0 1 1 3 10
        1 2 2 0 20
        2 3 3 2 30
        3 4 -1 -1 0
        4
        """,
        num_aux_labels=1,
    )

    fsa2 = k2.Fsa.from_str(
        """
        0 1 1 10 1
        1 2 2 5 2
        2 3 -1 -1 0
        3
        """,
        num_aux_labels=1,
    )
    fsas = k2.Fsa.from_fsas([fsa1, fsa2])
    texts = get_texts(fsas)
    assert texts == [[3, 2], [10, 5]]


def test_attribute_dict():
    s = AttributeDict({"a": 10, "b": 20})
    assert s.a == 10
    assert s["b"] == 20
    s.c = 100
    assert s["c"] == 100