Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 17:42:21 +00:00)

Add CTC training.

commit f3542c7793 (parent a01d08f73c)
.gitignore (vendored, 1 line added)
@@ -1,3 +1,4 @@
 data
 __pycache__
 path.sh
+exp
@@ -15,7 +15,7 @@ repos:
     rev: 5.9.2
     hooks:
       - id: isort
-        args: [--profile=black]
+        args: [--profile=black, --line-length=80]

 - repo: https://github.com/pre-commit/pre-commit-hooks
   rev: v4.0.1
egs/librispeech/ASR/local/compile_hlg.py (new file, 82 lines)
@@ -0,0 +1,82 @@
#!/usr/bin/env python3

"""
This script compiles HLG from

- H, the ctc topology, built from phones contained in data/lang/lexicon.txt
- L, the lexicon, built from data/lang/L_disambig.pt

  Caution: We use a lexicon that contains disambiguation symbols

- G, the LM, built from data/lm/G_3_gram.fst.txt

The generated HLG is saved in data/lm/HLG.pt
"""
import k2
import torch

from icefall.lexicon import Lexicon


def main():
    lexicon = Lexicon("data/lang")
    max_token_id = max(lexicon.tokens)
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load("data/lang/L_disambig.pt"))
    with open("data/lm/G_3_gram.fst.txt") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)

    first_token_disambig_id = lexicon.phones["#0"]
    first_word_disambig_id = lexicon.words["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    print("Intersecting L and G")
    LG = k2.compose(L, G)
    print(f"LG shape: {LG.shape}")

    print("Connecting LG")
    LG = k2.connect(LG)
    print(f"LG shape after k2.connect: {LG.shape}")

    print(type(LG.aux_labels))
    print("Determinizing LG")

    LG = k2.determinize(LG)
    print(type(LG.aux_labels))

    print("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    print("Removing disambiguation symbols on LG")

    LG.labels[LG.labels >= first_token_disambig_id] = 0

    assert isinstance(LG.aux_labels, k2.RaggedInt)
    LG.aux_labels.values()[LG.aux_labels.values() >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    print(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = k2.ragged.remove_values_eq(LG.aux_labels, 0)

    print("Arc sorting LG")
    LG = k2.arc_sort(LG)

    print("Composing H and LG")
    HLG = k2.compose(H, LG, inner_labels="phones")

    print("Connecting HLG")
    HLG = k2.connect(HLG)

    print("Arc sorting HLG")
    HLG = k2.arc_sort(HLG)

    print("Saving HLG.pt to data/lm")
    torch.save(HLG.as_dict(), "data/lm/HLG.pt")


if __name__ == "__main__":
    main()
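As a quick sanity check, the saved graph can be loaded back the same way decode.py later in this commit does. A minimal sketch, not part of the commit; the path is the one used by the script above:

```python
import k2
import torch

# Load the graph that compile_hlg.py just wrote.
HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
assert HLG.requires_grad is False
print("HLG shape:", HLG.shape)  # a single FSA reports (num_states, None)
```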
@@ -231,14 +231,18 @@ def add_self_loops(
       arcs:
         A list-of-list. The sublist contains
         `[src_state, dest_state, label, aux_label, score]`
+      disambig_phone:
+        It is the phone ID of the symbol `#0`.
+      disambig_word:
+        It is the word ID of the symbol `#0`.
     Return:
-      Return new `arcs` that contain self-loops.
+      Return new `arcs` containing self-loops.
     """
     states_needs_self_loops = set()
     for arc in arcs:
-        src, dst, ilable, olable, score = arc
-        if olable != 0:
+        src, dst, ilabel, olabel, score = arc
+        if olabel != 0:
             states_needs_self_loops.add(src)

     ans = []
@@ -396,11 +400,11 @@ def main():
         sil_prob=sil_prob,
         need_self_loops=True,
     )
+    # Just for debugging, will remove it
+    torch.save(L.as_dict(), out_dir / "L.pt")
+    torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")

     if False:
-        # Just for debugging, will remove it
-        torch.save(L.as_dict(), out_dir / "L.pt")
-        torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
-
         L.labels_sym = k2.SymbolTable.from_file(out_dir / "phones.txt")
         L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
@@ -80,7 +80,18 @@ def test_read_lexicon(filename: str):
     fsa_disambig.draw("L_disambig.pdf", title="L_disambig")


-if __name__ == "__main__":
+def test_lexicon():
+    from icefall.lexicon import Lexicon
+
+    lexicon = Lexicon("data/lang")
+    print(lexicon.tokens)
+
+
+def main():
     filename = generate_lexicon_file()
     test_read_lexicon(filename)
     os.remove(filename)
+
+
+if __name__ == "__main__":
+    test_lexicon()
@@ -87,3 +87,24 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then

   ./local/prepare_lang.py
 fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  echo "Stage 6: Prepare G"
+  # We assume you have installed kaldilm. If not, please install
+  # it using: pip install kaldilm
+
+  if [ ! -e data/lm/G_3_gram.fst.txt ]; then
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  echo "Stage 7: Compile HLG"
+  if [ ! -f data/lm/HLG.pt ]; then
+    python3 ./local/compile_hlg.py
+  fi
+fi
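Stage 6 writes G_3_gram.fst.txt in OpenFst text form and stage 7 consumes it through compile_hlg.py. A small sketch, not part of the commit, of how that file is read back with k2, mirroring the script shown earlier:

```python
import k2

# Read the OpenFst text file produced by kaldilm in stage 6, exactly as
# compile_hlg.py (stage 7) does; acceptor=False keeps word IDs on both
# input and output labels.
with open("data/lm/G_3_gram.fst.txt") as f:
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)
print("G has", G.num_arcs, "arcs")
```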
egs/librispeech/ASR/tdnn_lstm_ctc/README.md (new file, 14 lines)
@@ -0,0 +1,14 @@
## (To be filled in)

It will contain:

- How to run
- WERs

```bash
cd $PWD/..

./prepare.sh

./tdnn_lstm_ctc/train.py
```
egs/librispeech/ASR/tdnn_lstm_ctc/__init__.py (new empty file)
egs/librispeech/ASR/tdnn_lstm_ctc/decode.py (new executable file, 210 lines)
@@ -0,0 +1,210 @@
#!/usr/bin/env python3


import argparse
import logging
from pathlib import Path
from typing import List, Tuple

import k2
import torch
import torch.nn as nn
from model import TdnnLstm

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    get_texts,
    setup_logger,
    write_error_stats,
)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=9,
        help="It specifies the checkpoint to use for decoding."
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=5,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )
    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "exp_dir": Path("tdnn_lstm_ctc/exp3/"),
            "lang_dir": Path("data/lang"),
            "feature_dim": 80,
            "subsampling_factor": 3,
            "search_beam": 20,
            "output_beam": 8,
            "min_active_states": 30,
            "max_active_states": 10000,
        }
    )
    return params


@torch.no_grad()
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: k2.Fsa,
    batch: dict,
    lexicon: Lexicon,
) -> List[Tuple[List[str], List[str]]]:
    """Decode one batch and return a list of tuples containing
    `(ref_words, hyp_words)`.

    Args:
      params:
        It is the return value of :func:`get_params`.
    """
    device = HLG.device
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is [N, T, C]

    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    nnet_output = model(feature)
    # nnet_output is [N, T, C]

    supervisions = batch["supervisions"]

    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"] // params.subsampling_factor,
            supervisions["num_frames"] // params.subsampling_factor,
        ),
        1,
    ).to(torch.int32)

    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

    lattices = k2.intersect_dense_pruned(
        HLG,
        dense_fsa_vec,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
    )

    best_paths = k2.shortest_path(lattices, use_double_scores=True)

    hyps = get_texts(best_paths)
    hyps = [[lexicon.words[i] for i in ids] for ids in hyps]

    texts = supervisions["text"]

    results = []
    for hyp_words, ref_text in zip(hyps, texts):
        ref_words = ref_text.split()
        results.append((ref_words, hyp_words))
    return results


def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_phone_id = max(lexicon.tokens)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
    HLG = HLG.to(device)
    assert HLG.requires_grad is False

    model = TdnnLstm(
        num_features=params.feature_dim,
        num_classes=max_phone_id + 1,  # +1 for the blank symbol
        subsampling_factor=params.subsampling_factor,
    )
    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if start >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames))

    model.to(device)
    model.eval()

    librispeech = LibriSpeechAsrDataModule(args)
    # CAUTION: `test_sets` is for displaying only.
    # If you want to skip test-clean, you have to skip
    # it inside the for loop. That is, use
    #
    #   if test_set == 'test-clean': continue
    #
    test_sets = ["test-clean", "test-other"]
    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
        tot_num_cuts = len(test_dl.dataset.cuts)
        num_cuts = 0

        results = []
        for batch_idx, batch in enumerate(test_dl):
            this_batch = decode_one_batch(
                params=params,
                model=model,
                HLG=HLG,
                batch=batch,
                lexicon=lexicon,
            )
            results.extend(this_batch)

            num_cuts += len(batch["supervisions"]["text"])

            if batch_idx % 100 == 0:
                logging.info(
                    f"batch {batch_idx}, cuts processed until now is "
                    f"{num_cuts}/{tot_num_cuts} "
                    f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
                )

        errs_filename = params.exp_dir / f"errs-{test_set}.txt"
        with open(errs_filename, "w") as f:
            write_error_stats(f, test_set, results)

    logging.info("Done!")


if __name__ == "__main__":
    main()
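The --epoch/--avg options select consecutive checkpoints ending at --epoch. A tiny sketch, not part of the commit, of the filename selection using the defaults above:

```python
# With --epoch 9 --avg 5, decode.py averages the last five epoch checkpoints.
epoch, avg = 9, 5
start = epoch - avg + 1
filenames = [f"tdnn_lstm_ctc/exp3/epoch-{i}.pt" for i in range(start, epoch + 1)]
print(filenames)  # epoch-5.pt ... epoch-9.pt, fed to average_checkpoints()
```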
egs/librispeech/ASR/tdnn_lstm_ctc/model.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn


class TdnnLstm(nn.Module):
    def __init__(
        self, num_features: int, num_classes: int, subsampling_factor: int = 3
    ) -> None:
        """
        Args:
          num_features:
            The input dimension of the model.
          num_classes:
            The output dimension of the model.
          subsampling_factor:
            It reduces the number of output frames by this factor.
        """
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.subsampling_factor = subsampling_factor
        self.tdnn = nn.Sequential(
            nn.Conv1d(
                in_channels=num_features,
                out_channels=500,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=500, affine=False),
            nn.Conv1d(
                in_channels=500,
                out_channels=500,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=500, affine=False),
            nn.Conv1d(
                in_channels=500,
                out_channels=500,
                kernel_size=3,
                stride=self.subsampling_factor,  # stride: subsampling_factor!
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=500, affine=False),
        )
        self.lstms = nn.ModuleList(
            [
                nn.LSTM(input_size=500, hidden_size=500, num_layers=1)
                for _ in range(5)
            ]
        )
        self.lstm_bnorms = nn.ModuleList(
            [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)]
        )
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(in_features=500, out_features=self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            Its shape is [N, C, T]

        Returns:
          The output tensor has shape [N, T, C]
        """
        x = self.tdnn(x)
        x = x.permute(2, 0, 1)  # (N, C, T) -> (T, N, C) -> how LSTM expects it
        for lstm, bnorm in zip(self.lstms, self.lstm_bnorms):
            x_new, _ = lstm(x)
            x_new = bnorm(x_new.permute(1, 2, 0)).permute(
                2, 0, 1
            )  # (T, N, C) -> (N, C, T) -> (T, N, C)
            x_new = self.dropout(x_new)
            x = x_new + x  # skip connections
        x = x.transpose(
            1, 0
        )  # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim
        x = self.linear(x)
        x = nn.functional.log_softmax(x, dim=-1)
        return x
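A quick shape check, not part of the commit, for the model above; the 72-class output size is only an illustrative assumption:

```python
import torch

# Feed a random [N, C, T] batch through the TdnnLstm defined above.
model = TdnnLstm(num_features=80, num_classes=72, subsampling_factor=3)
x = torch.randn(2, 80, 100)  # [N, C, T]
y = model(x)
print(y.shape)  # [N, T', num_classes]; T' is roughly T / subsampling_factor
assert y.shape[0] == 2 and y.shape[2] == 72
```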
egs/librispeech/ASR/tdnn_lstm_ctc/train.py (new executable file, 493 lines)
@@ -0,0 +1,493 @@
#!/usr/bin/env python3

# This is just at the very beginning ...

import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional

import k2
import torch
import torch.nn as nn
import torch.optim as optim
from model import TdnnLstm
from torch.nn.utils import clip_grad_value_
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, encode_supervisions, setup_logger


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12354,
        help="Master port to use for DDP training.",
    )
    # TODO: add extra arguments and support DDP training.
    # Currently, only single GPU training is implemented. Will add
    # DDP training once single GPU training is finished.
    return parser


def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

    - exp_dir: It specifies the directory where all training related
               files, e.g., checkpoints, log, etc., are saved

    - lang_dir: It contains language related input files such as
                "lexicon.txt"

    - lr: It specifies the initial learning rate

    - feature_dim: The model input dim. It has to match the one used
                   in computing features.

    - weight_decay: The weight_decay for the optimizer.

    - subsampling_factor: The subsampling factor for the model.

    - start_epoch: If it is not zero, load checkpoint `start_epoch-1`
                   and continue training from that checkpoint.

    - num_epochs: Number of epochs to train.

    - best_train_loss: Best training loss so far. It is used to select
                       the model that has the lowest training loss. It is
                       updated during the training.

    - best_valid_loss: Best validation loss so far. It is used to select
                       the model that has the lowest validation loss. It is
                       updated during the training.

    - best_train_epoch: It is the epoch that has the best training loss.

    - best_valid_epoch: It is the epoch that has the best validation loss.

    - batch_idx_train: Used for writing statistics to tensorboard. It
                       contains the number of batches trained so far across
                       epochs.

    - log_interval: Print training loss if `batch_idx % log_interval` is 0

    - valid_interval: Run validation if `batch_idx % valid_interval` is 0

    - beam_size: It is used in k2.ctc_loss

    - reduction: It is used in k2.ctc_loss

    - use_double_scores: It is used in k2.ctc_loss
    """
    params = AttributeDict(
        {
            "exp_dir": Path("tdnn_lstm_ctc/exp"),
            "lang_dir": Path("data/lang"),
            "lr": 1e-3,
            "feature_dim": 80,
            "weight_decay": 5e-4,
            "subsampling_factor": 3,
            "start_epoch": 0,
            "num_epochs": 10,
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 10,
            "valid_interval": 1000,
            "beam_size": 10,
            "reduction": "sum",
            "use_double_scores": True,
        }
    )

    return params


def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler = None,
) -> None:
    """Load checkpoint from file.

    If params.start_epoch is positive, it will load the checkpoint from
    `params.start_epoch - 1`. Otherwise, this function does nothing.

    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    and `best_valid_loss` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The learning rate scheduler we are using.
    Returns:
      Return None.
    """
    if params.start_epoch <= 0:
        return

    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    saved_params = load_checkpoint(
        filename,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]


def save_checkpoint(
    params: AttributeDict,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
) -> None:
    """Save model, optimizer, scheduler and training stats to file.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The training model.
    """
    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    if params.best_train_epoch == params.cur_epoch:
        best_train_filename = params.exp_dir / "best-train-loss.pt"
        copyfile(src=filename, dst=best_train_filename)

    if params.best_valid_epoch == params.cur_epoch:
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)


def compute_loss(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
    graph_compiler: CtcTrainingGraphCompiler,
    is_training: bool,
):
    """
    Compute CTC loss given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training. It is an instance of TdnnLstm in our case.
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      graph_compiler:
        It is used to build a decoding graph from a ctc topo and training
        transcript. The training transcript is contained in the given `batch`,
        while the ctc topo is built when this compiler is instantiated.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
    """
    device = graph_compiler.device
    feature = batch["inputs"]
    # at entry, feature is [N, T, C]
    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
    assert feature.ndim == 3
    feature = feature.to(device)

    with torch.set_grad_enabled(is_training):
        nnet_output = model(feature)
        # nnet_output is [N, T, C]

    # NOTE: We need `encode_supervisions` to sort sequences with
    # different duration in decreasing order, required by
    # `k2.intersect_dense` called in `k2.ctc_loss`
    supervisions = batch["supervisions"]
    supervision_segments, texts = encode_supervisions(
        supervisions, subsampling_factor=params.subsampling_factor
    )
    decoding_graph = graph_compiler.compile(texts)

    dense_fsa_vec = k2.DenseFsaVec(
        nnet_output,
        supervision_segments,
        allow_truncate=params.subsampling_factor - 1,
    )

    loss = k2.ctc_loss(
        decoding_graph=decoding_graph,
        dense_fsa_vec=dense_fsa_vec,
        output_beam=params.beam_size,
        reduction=params.reduction,
        use_double_scores=params.use_double_scores,
    )

    assert loss.requires_grad == is_training

    # train_frames and valid_frames are used for printing.
    if is_training:
        params.train_frames = supervision_segments[:, 2].sum().item()
    else:
        params.valid_frames = supervision_segments[:, 2].sum().item()

    return loss


def compute_validation_loss(
    params: AttributeDict,
    model: nn.Module,
    graph_compiler: CtcTrainingGraphCompiler,
    valid_dl: torch.utils.data.DataLoader,
) -> None:
    """Run the validation process. The validation loss
    is saved in `params.valid_loss`.
    """
    model.eval()

    tot_loss = 0.0
    tot_frames = 0.0
    for batch_idx, batch in enumerate(valid_dl):
        loss = compute_loss(
            params=params,
            model=model,
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=False,
        )
        assert loss.requires_grad is False

        loss_cpu = loss.detach().cpu().item()
        tot_loss += loss_cpu
        tot_frames += params.valid_frames

    params.valid_loss = tot_loss / tot_frames

    if params.valid_loss < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = params.valid_loss


def train_one_epoch(
    params: AttributeDict,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    graph_compiler: CtcTrainingGraphCompiler,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    tb_writer: Optional[SummaryWriter] = None,
) -> None:
    """Train the model for one epoch.

    The training loss from the mean of all frames is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer we are using.
      graph_compiler:
        It is used to convert transcripts to FSAs.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      tb_writer:
        Writer to write log messages to tensorboard.
    """
    model.train()

    tot_loss = 0.0  # sum of losses over all batches
    tot_frames = 0.0  # sum of frames over all batches
    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])

        loss = compute_loss(
            params=params,
            model=model,
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=True,
        )

        # NOTE: We use reduction==sum and loss is computed over utterances
        # in the batch and there is no normalization to it so far.

        optimizer.zero_grad()
        loss.backward()
        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()

        loss_cpu = loss.detach().cpu().item()

        tot_frames += params.train_frames
        tot_loss += loss_cpu
        tot_avg_loss = tot_loss / tot_frames

        if batch_idx % params.log_interval == 0:
            logging.info(
                f"Epoch {params.cur_epoch}, batch {batch_idx}, "
                f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
                f"total avg loss: {tot_avg_loss:.4f}, "
                f"batch size: {batch_size}"
            )

        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
            compute_validation_loss(
                params=params,
                model=model,
                graph_compiler=graph_compiler,
                valid_dl=valid_dl,
            )
            model.train()
            logging.info(
                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss}, "
                f"best valid loss: {params.best_valid_loss:.4f} "
                f"best valid epoch: {params.best_valid_epoch}"
            )

    params.train_loss = tot_loss / tot_frames

    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")
    logging.info(params)

    tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")

    lexicon = Lexicon(params.lang_dir)
    max_phone_id = max(lexicon.tokens)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)

    model = TdnnLstm(
        num_features=params.feature_dim,
        num_classes=max_phone_id + 1,  # +1 for the blank symbol
        subsampling_factor=params.subsampling_factor,
    )
    model.to(device)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=params.lr,
        weight_decay=params.weight_decay,
    )
    scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

    load_checkpoint_if_available(
        params=params, model=model, optimizer=optimizer
    )

    librispeech = LibriSpeechAsrDataModule(args)
    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()

    for epoch in range(params.start_epoch, params.num_epochs):
        train_dl.sampler.set_epoch(epoch)

        if epoch > params.start_epoch:
            logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}")

        if tb_writer is not None:
            tb_writer.add_scalar(
                "train/lr",
                scheduler.get_last_lr()[0],
                params.batch_idx_train,
            )
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            optimizer=optimizer,
            graph_compiler=graph_compiler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            tb_writer=tb_writer,
        )

        scheduler.step()

        save_checkpoint(
            params=params, model=model, optimizer=optimizer, scheduler=scheduler
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
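Because the CTC loss uses reduction="sum" with no per-batch normalization, the averages logged in train_one_epoch are summed losses divided by summed frame counts. A sketch with hypothetical numbers, not part of the commit:

```python
# Hypothetical numbers only: two batches with summed CTC losses and their
# frame counts (params.train_frames in the loop above).
batch_losses = [4200.0, 3900.0]
batch_frames = [1500, 1400]

tot_loss = sum(batch_losses)
tot_frames = sum(batch_frames)
print("batch avg loss:", batch_losses[-1] / batch_frames[-1])
print("total avg loss:", tot_loss / tot_frames)
```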
icefall/checkpoint.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler


def save_checkpoint(
    filename: Path,
    model: Union[nn.Module, DDP],
    params: Optional[Dict[str, Any]] = None,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[_LRScheduler] = None,
    scaler: Optional[GradScaler] = None,
    rank: int = 0,
) -> None:
    """Save training information to a file.

    Args:
      filename:
        The checkpoint filename.
      model:
        The model to be saved. We only save its `state_dict()`.
      params:
        User defined parameters, e.g., epoch, loss.
      optimizer:
        The optimizer to be saved. We only save its `state_dict()`.
      scheduler:
        The scheduler to be saved. We only save its `state_dict()`.
      scaler:
        The GradScaler to be saved. We only save its `state_dict()`.
      rank:
        Used in DDP. We save checkpoint only for the node whose rank is 0.
    Returns:
      Return None.
    """
    if rank != 0:
        return

    logging.info(f"Saving checkpoint to {filename}")

    if isinstance(model, DDP):
        model = model.module

    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
    }

    if params:
        for k, v in params.items():
            assert k not in checkpoint
            checkpoint[k] = v

    torch.save(checkpoint, filename)


def load_checkpoint(
    filename: Path,
    model: nn.Module,
    optimizer: Optional[Optimizer] = None,
    scheduler: Optional[_LRScheduler] = None,
    scaler: Optional[GradScaler] = None,
) -> Dict[str, Any]:
    """
    TODO: document it
    """
    logging.info(f"Loading checkpoint from {filename}")
    checkpoint = torch.load(filename, map_location="cpu")

    if next(iter(checkpoint["model"])).startswith("module."):
        logging.info("Loading checkpoint saved by DDP")

        dst_state_dict = model.state_dict()
        src_state_dict = checkpoint["model"]
        for key in dst_state_dict.keys():
            src_key = "{}.{}".format("module", key)
            dst_state_dict[key] = src_state_dict.pop(src_key)
        assert len(src_state_dict) == 0
        model.load_state_dict(dst_state_dict, strict=False)
    else:
        model.load_state_dict(checkpoint["model"], strict=False)

    checkpoint.pop("model")

    def load(name, obj):
        s = checkpoint[name]
        if obj and s:
            obj.load_state_dict(s)
        checkpoint.pop(name)

    load("optimizer", optimizer)
    load("scheduler", scheduler)
    load("grad_scaler", scaler)

    return checkpoint


def average_checkpoints(filenames: List[Path]) -> dict:
    """Average a list of checkpoints.

    Args:
      filenames:
        Filenames of the checkpoints to be averaged. We assume all
        checkpoints are saved by :func:`save_checkpoint`.
    Returns:
      Return a dict (i.e., state_dict) which is the average of all
      model state dicts contained in the checkpoints.
    """
    n = len(filenames)

    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for i in range(1, n):
        state_dict = torch.load(filenames[i], map_location="cpu")["model"]
        for k in avg:
            avg[k] += state_dict[k]

    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= n
        else:
            avg[k] //= n

    return avg
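A self-contained usage sketch, not part of the commit, of the three helpers above, using a toy nn.Linear and a temporary directory:

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

with tempfile.TemporaryDirectory() as d:
    model = nn.Linear(4, 2)
    ckpt = Path(d) / "epoch-0.pt"
    save_checkpoint(ckpt, model=model, params={"cur_epoch": 0})

    model2 = nn.Linear(4, 2)
    extra = load_checkpoint(ckpt, model=model2)
    print(extra["cur_epoch"])  # user-defined params are returned

    # Averaging two identical checkpoints gives back the same weights.
    avg = average_checkpoints([ckpt, ckpt])
    assert torch.allclose(avg["weight"], model.state_dict()["weight"])
```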
icefall/dataset/__init__.py (new empty file)
icefall/dataset/asr_datamodule.py (new file, 248 lines)
@@ -0,0 +1,248 @@
import argparse
import logging
from pathlib import Path
from typing import List, Union

from lhotse import Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
    BucketingSampler,
    CutConcatenate,
    CutMix,
    K2SpeechRecognitionDataset,
    SingleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader

from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool


class AsrDataModule(DataModule):
    """
    DataModule for K2 ASR experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - augmentation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        super().add_arguments(parser)
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--feature-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=500.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=False,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the BucketingSampler"
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=True,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=1.0,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )

    def train_dataloaders(self) -> DataLoader:
        logging.info("About to get train cuts")
        cuts_train = self.train_cuts()

        logging.info("About to get Musan cuts")
        cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")

        logging.info("About to create train dataset")
        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = [
            SpecAugment(
                num_frame_masks=2,
                features_mask_size=27,
                num_feature_masks=2,
                frames_mask_size=100,
            )
        ]

        train = K2SpeechRecognitionDataset(
            cuts_train,
            cut_transforms=transforms,
            input_transforms=input_transforms,
        )

        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
            # Drop feats to be on the safe side.
            cuts_train = cuts_train.drop_features()
            train = K2SpeechRecognitionDataset(
                cuts=cuts_train,
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                ),
                input_transforms=input_transforms,
            )

        if self.args.bucketing_sampler:
            logging.info("Using BucketingSampler.")
            train_sampler = BucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=True,
                num_buckets=self.args.num_buckets,
            )
        else:
            logging.info("Using SingleCutSampler.")
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=True,
            )
        logging.info("About to create train dataloader")
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=4,
            persistent_workers=True,
        )
        return train_dl

    def valid_dataloaders(self) -> DataLoader:
        logging.info("About to get dev cuts")
        cuts_valid = self.valid_cuts()

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            cuts_valid = cuts_valid.drop_features()
            validate = K2SpeechRecognitionDataset(
                cuts_valid.drop_features(),
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                ),
            )
        else:
            validate = K2SpeechRecognitionDataset(cuts_valid)
        valid_sampler = SingleCutSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=True,
        )
        return valid_dl

    def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        cuts = self.test_cuts()
        is_list = isinstance(cuts, list)
        test_loaders = []
        if not is_list:
            cuts = [cuts]

        for cuts_test in cuts:
            logging.debug("About to create test dataset")
            test = K2SpeechRecognitionDataset(
                cuts_test,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                ),
            )
            sampler = SingleCutSampler(
                cuts_test, max_duration=self.args.max_duration
            )
            logging.debug("About to create test dataloader")
            test_dl = DataLoader(
                test, batch_size=None, sampler=sampler, num_workers=1
            )
            test_loaders.append(test_dl)

        if is_list:
            return test_loaders
        else:
            return test_loaders[0]
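The options registered by add_arguments can be inspected without touching any data, assuming the module's imports (lhotse, icefall.utils) are available. A sketch, not part of the commit; the sample values are arbitrary:

```python
import argparse

# Register the data options on a bare parser and parse a couple of samples.
parser = argparse.ArgumentParser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "200", "--bucketing-sampler", "True"])
print(args.max_duration, args.bucketing_sampler, args.on_the_fly_feats)
```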
icefall/dataset/datamodule.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import argparse
from typing import List, Union

from lhotse import CutSet
from torch.utils.data import DataLoader


class DataModule:
    """
    Contains dataset-related code. It is intended to read/construct Lhotse
    cuts, and create Dataset/Sampler/DataLoader out of them.

    There is a separate method to create each of the train/valid/test
    DataLoaders. In principle, there might be multiple DataLoaders for each
    of train/valid/test (e.g. when a corpus has multiple test sets).
    The API of this class allows returning lists of CutSets/DataLoaders.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        pass

    def train_cuts(self) -> Union[CutSet, List[CutSet]]:
        raise NotImplementedError()

    def valid_cuts(self) -> Union[CutSet, List[CutSet]]:
        raise NotImplementedError()

    def test_cuts(self) -> Union[CutSet, List[CutSet]]:
        raise NotImplementedError()

    def train_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        raise NotImplementedError()

    def valid_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        raise NotImplementedError()

    def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        raise NotImplementedError()
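A minimal subclassing sketch, not part of the commit; the class and option names are made up for illustration:

```python
# A hypothetical corpus-specific subclass; only the methods a recipe actually
# calls need to be overridden.
class ToyDataModule(DataModule):
    @classmethod
    def add_arguments(cls, parser):
        super().add_arguments(parser)
        parser.add_argument("--toy-option", type=int, default=1)

    def train_cuts(self):
        # A real subclass would return a lhotse CutSet here.
        raise NotImplementedError("toy example only")
```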
icefall/dataset/librispeech.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import argparse
import logging
from functools import lru_cache
from typing import List

from lhotse import CutSet, load_manifest

from icefall.dataset.asr_datamodule import AsrDataModule
from icefall.utils import str2bool


class LibriSpeechAsrDataModule(AsrDataModule):
    """
    LibriSpeech ASR data module. Can be used for 100h subset
    (``--full-libri false``) or full 960h set.
    The train and valid cuts for standard Libri splits are
    concatenated into a single CutSet/DataLoader.
    """

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        super().add_arguments(parser)
        group = parser.add_argument_group(title="LibriSpeech specific options")
        group.add_argument(
            "--full-libri",
            type=str2bool,
            default=True,
            help="When enabled, use 960h LibriSpeech.",
        )

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        cuts_train = load_manifest(
            self.args.feature_dir / "cuts_train-clean-100.json.gz"
        )
        if self.args.full_libri:
            cuts_train = (
                cuts_train
                + load_manifest(
                    self.args.feature_dir / "cuts_train-clean-360.json.gz"
                )
                + load_manifest(
                    self.args.feature_dir / "cuts_train-other-500.json.gz"
                )
            )
        return cuts_train

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        cuts_valid = load_manifest(
            self.args.feature_dir / "cuts_dev-clean.json.gz"
        ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
        return cuts_valid

    @lru_cache()
    def test_cuts(self) -> List[CutSet]:
        test_sets = ["test-clean", "test-other"]
        cuts = []
        for test_set in test_sets:
            logging.debug("About to get test cuts")
            cuts.append(
                load_manifest(
                    self.args.feature_dir / f"cuts_{test_set}.json.gz"
                )
            )
        return cuts
109
icefall/graph_compiler.py
Normal file
109
icefall/graph_compiler.py
Normal file
@ -0,0 +1,109 @@
from typing import List

import k2
import torch

from icefall.lexicon import Lexicon


class CtcTrainingGraphCompiler(object):
    def __init__(
        self,
        lexicon: Lexicon,
        device: torch.device,
        oov: str = "<UNK>",
    ):
        """
        Args:
          lexicon:
            It is built from `data/lang/lexicon.txt`.
          device:
            The device to use for operations compiling transcripts to FSAs.
          oov:
            Out of vocabulary word. When a word in the transcript
            does not exist in the lexicon, it is replaced with `oov`.
        """
        L_inv = lexicon.L_inv.to(device)
        assert L_inv.requires_grad is False

        assert oov in lexicon.words

        self.L_inv = k2.arc_sort(L_inv)
        self.oov_id = lexicon.words[oov]
        self.words = lexicon.words

        max_token_id = max(lexicon.tokens)
        ctc_topo = k2.ctc_topo(max_token_id, modified=False)

        self.ctc_topo = ctc_topo.to(device)
        self.device = device

    def compile(self, texts: List[str]) -> k2.Fsa:
        """Build decoding graphs by composing ctc_topo with
        the given transcripts.

        Args:
          texts:
            A list of strings. Each string contains a sentence for an
            utterance. A sentence consists of space-separated words.
            An example `texts` looks like:

                ['hello icefall', 'CTC training with k2']

        Returns:
          An FsaVec, the composition result of `self.ctc_topo` and the
          transcript FSA.
        """
        transcript_fsa = self.convert_transcript_to_fsa(texts)

        # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially
        # is False, so we add epsilon self-loops here
        fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
            transcript_fsa
        )

        fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops)

        decoding_graph = k2.compose(
            self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False
        )

        assert decoding_graph.requires_grad is False

        return decoding_graph

    def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa:
        """Convert a list of transcript texts to an FsaVec.

        Args:
          texts:
            A list of strings. Each string contains a sentence for an
            utterance. A sentence consists of space-separated words.
            An example `texts` looks like:

                ['hello icefall', 'CTC training with k2']

        Returns:
          Return an FsaVec, whose `shape[0]` equals `len(texts)`.
        """
        word_ids_list = []
        for text in texts:
            word_ids = []
            for word in text.split(" "):
                if word in self.words:
                    word_ids.append(self.words[word])
                else:
                    word_ids.append(self.oov_id)
            word_ids_list.append(word_ids)

        word_fsa = k2.linear_fsa(word_ids_list, self.device)

        word_fsa_with_self_loops = k2.add_epsilon_self_loops(word_fsa)

        fsa = k2.intersect(
            self.L_inv, word_fsa_with_self_loops, treat_epsilons_specially=False
        )
        # fsa has word IDs as labels and token IDs as aux_labels, so
        # we need to invert it
        ans_fsa = fsa.invert_()
        return k2.arc_sort(ans_fsa)
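For reference, a minimal usage sketch of the compiler above, assuming `data/lang` has already been created by `prepare.sh` (the transcripts are the illustrative ones from the docstring):

import torch

from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon

# Illustrative only; not part of the commit.
lexicon = Lexicon("data/lang")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
compiler = CtcTrainingGraphCompiler(lexicon, device=device)

# One graph per transcript; the resulting FsaVec is later intersected with
# the network output to compute the CTC loss.
decoding_graph = compiler.compile(["hello icefall", "CTC training with k2"])
assert decoding_graph.shape[0] == 2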
66
icefall/lexicon.py
Normal file
@ -0,0 +1,66 @@
import logging
import re
from pathlib import Path
from typing import List

import k2
import torch


class Lexicon(object):
    """Phone-based lexicon.

    TODO: Add BpeLexicon for BPE models.
    """

    def __init__(
        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$")
    ):
        """
        Args:
          lang_dir:
            Path to the lang directory. It is expected to contain the
            following files:

            - phones.txt
            - words.txt
            - L.pt

            The above files are produced by the script `prepare.sh`. You
            should have run it before running the training code.
          disambig_pattern:
            It contains the pattern for disambiguation symbols.
        """
        lang_dir = Path(lang_dir)
        self.phones = k2.SymbolTable.from_file(lang_dir / "phones.txt")
        self.words = k2.SymbolTable.from_file(lang_dir / "words.txt")

        if (lang_dir / "Linv.pt").exists():
            logging.info("Loading pre-compiled Linv.pt")
            L_inv = k2.Fsa.from_dict(torch.load(lang_dir / "Linv.pt"))
        else:
            logging.info("Converting L.pt to Linv.pt")
            L = k2.Fsa.from_dict(torch.load(lang_dir / "L.pt"))
            L_inv = k2.arc_sort(L.invert())
            torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")

        # We save L_inv instead of L because it will be used to intersect with
        # transcript FSAs, both of whose labels are word IDs.
        self.L_inv = L_inv
        self.disambig_pattern = disambig_pattern

    @property
    def tokens(self) -> List[int]:
        """Return a list of phone IDs excluding those from
        disambiguation symbols.

        Caution:
          0 is not a phone ID, so it is excluded from the return value.
        """
        symbols = self.phones.symbols
        ans = []
        for s in symbols:
            if not self.disambig_pattern.match(s):
                ans.append(self.phones[s])
        if 0 in ans:
            ans.remove(0)
        ans.sort()
        return ans
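A minimal sketch of how the lexicon is typically queried, again assuming `data/lang` exists and defines an `<UNK>` entry (both assumptions, not guaranteed by this commit):

from icefall.lexicon import Lexicon

# Illustrative only; not part of the commit.
lexicon = Lexicon("data/lang")
max_token_id = max(lexicon.tokens)  # used to build the CTC topology
print("num phones (excluding disambig symbols):", len(lexicon.tokens))
print("word id of <UNK>:", lexicon.words["<UNK>"])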
298
icefall/utils.py
@ -1,5 +1,20 @@
import argparse
import logging
import os
import subprocess
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Dict, List, TextIO, Tuple, Union

import k2
import k2.ragged as k2r
import kaldialign
import torch
import torch.distributed as dist

Pathlike = Union[str, Path]


@contextmanager
@ -32,3 +47,286 @@ def get_executor():
    # No need to return anything - compute_and_store_features
    # will just instantiate the pool itself.
    yield None


def str2bool(v):
    """Used in argparse.ArgumentParser.add_argument to indicate
    that a type is a bool type and the user can enter

        - yes, true, t, y, 1, to represent True
        - no, false, f, n, 0, to represent False

    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")


def setup_logger(
    log_filename: Pathlike, log_level: str = "info", use_console: bool = True
) -> None:
    """Set up logging to a file and, optionally, to the console.

    Args:
      log_filename:
        The filename to save the log to. A timestamp (and the rank, when
        running distributed) is appended to it.
      log_level:
        The log level to use, e.g., "debug", "info", "warning", "error",
        "critical".
      use_console:
        If True, also emit log messages to the console.
    """
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M-%S")

    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s"  # noqa
        log_filename = f"{log_filename}-{date_time}-{rank}"
    else:
        formatter = (
            "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
        )
        log_filename = f"{log_filename}-{date_time}"

    os.makedirs(os.path.dirname(log_filename), exist_ok=True)

    level = logging.ERROR
    if log_level == "debug":
        level = logging.DEBUG
    elif log_level == "info":
        level = logging.INFO
    elif log_level == "warning":
        level = logging.WARNING
    elif log_level == "critical":
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename, format=formatter, level=level, filemode="w"
    )
    if use_console:
        console = logging.StreamHandler()
        console.setLevel(level)
        console.setFormatter(logging.Formatter(formatter))
        logging.getLogger("").addHandler(console)


def get_env_info():
    """
    TODO:
    """
    return {
        "k2-git-sha1": None,
        "k2-version": None,
        "lhotse-version": None,
        "torch-version": None,
        "icefall-sha1": None,
        "icefall-version": None,
    }


# See
# https://stackoverflow.com/questions/4984647/accessing-dict-keys-like-an-attribute  # noqa
class AttributeDict(dict):
    __slots__ = ()
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__


def encode_supervisions(
    supervisions: Dict[str, torch.Tensor], subsampling_factor: int
) -> Tuple[torch.Tensor, List[str]]:
    """
    Encodes Lhotse's ``batch["supervisions"]`` dict into a pair:
    a supervision tensor and a list of transcription strings.

    The supervision tensor has shape ``(batch_size, 3)``.
    Its second dimension contains information about the sequence index [0],
    start frames [1] and num frames [2].

    The batch items might become re-ordered during this operation -- the
    returned tensor and list of strings are guaranteed to be consistent with
    each other.
    """
    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"] // subsampling_factor,
            supervisions["num_frames"] // subsampling_factor,
        ),
        1,
    ).to(torch.int32)

    indices = torch.argsort(supervision_segments[:, 2], descending=True)
    supervision_segments = supervision_segments[indices]
    texts = supervisions["text"]
    texts = [texts[idx] for idx in indices]

    return supervision_segments, texts


def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
    """Extract the texts from the best-path FSAs.

    Args:
      best_paths:
        A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e.
        containing multiple FSAs, which is expected to be the result
        of k2.shortest_path (otherwise the returned values won't
        be meaningful).
    Returns:
      Returns a list of lists of int, containing the label sequences we
      decoded.
    """
    if isinstance(best_paths.aux_labels, k2.RaggedInt):
        # remove 0's and -1's.
        aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0)
        aux_shape = k2r.compose_ragged_shapes(
            best_paths.arcs.shape(), aux_labels.shape()
        )

        # remove the states and arcs axes.
        aux_shape = k2r.remove_axis(aux_shape, 1)
        aux_shape = k2r.remove_axis(aux_shape, 1)
        aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())
    else:
        # remove the axis corresponding to states.
        aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
        aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels)
        # remove 0's and -1's.
        aux_labels = k2r.remove_values_leq(aux_labels, 0)

    assert aux_labels.num_axes() == 2
    return k2r.to_list(aux_labels)


def write_error_stats(
    f: TextIO, test_set_name: str, results: List[Tuple[str, str]]
) -> float:
    """Write detailed WER statistics for (reference, hypothesis) pairs to
    the file ``f`` and return the overall error rate (in percent) as a float.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"
    for ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum([len(r) for r, _ in results])
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs
    tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

    logging.info(
        f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
        f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
        f"{del_errs} del, {sub_errs} sub ]"
    )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
    for ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR)
        combine_successive_errors = True
        if combine_successive_errors:
            ali = [[[x], [y]] for x, y in ali]
            for i in range(len(ali) - 1):
                if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
                    ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
                    ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
                    ali[i] = [[], []]
            ali = [
                [
                    list(filter(lambda a: a != ERR, x)),
                    list(filter(lambda a: a != ERR, y)),
                ]
                for x, y in ali
            ]
            ali = list(filter(lambda x: x != [[], []], ali))
            ali = [
                [
                    ERR if x == [] else " ".join(x),
                    ERR if y == [] else " ".join(y),
                ]
                for x, y in ali
            ]

        print(
            " ".join(
                (
                    ref_word
                    if ref_word == hyp_word
                    else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in ali
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted(
        [(v, k) for k, v in subs.items()], reverse=True
    ):
        print(f"{count} {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count} {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count} {hyp}", file=f)

    print("", file=f)
    print(
        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
    )
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, ins, dels) = counts
        tot_errs = ref_sub + hyp_sub + ins + dels
        ref_count = corr + ref_sub + dels
        hyp_count = corr + hyp_sub + ins

        print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)
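The helpers above are meant to be called from the training and decoding scripts. A minimal sketch of wiring `str2bool` into an argument parser and initializing logging; the flag name and log path are illustrative, not taken from this commit:

import argparse

from icefall.utils import setup_logger, str2bool

# Illustrative only; not part of the commit.
parser = argparse.ArgumentParser()
parser.add_argument("--use-ali", type=str2bool, default=False)  # hypothetical flag
args = parser.parse_args(["--use-ali", "yes"])
assert args.use_ali is True

# Logs go to exp/log/log-train-<timestamp> and, optionally, the console.
setup_logger("exp/log/log-train", log_level="info", use_console=True)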
50
test/test_checkpoint.py
Normal file
@ -0,0 +1,50 @@
#!/usr/bin/env python3

import pytest
import torch
import torch.nn as nn

from icefall.checkpoint import (
    average_checkpoints,
    load_checkpoint,
    save_checkpoint,
)


@pytest.fixture
def checkpoints1(tmp_path):
    f = tmp_path / "f.pt"
    m = nn.Module()
    m.p1 = nn.Parameter(torch.tensor([10.0, 20.0]), requires_grad=False)
    m.register_buffer("p2", torch.tensor([10, 100]))

    params = {"a": 10, "b": 20}
    save_checkpoint(f, m, params=params)
    return f


@pytest.fixture
def checkpoints2(tmp_path):
    f = tmp_path / "f2.pt"
    m = nn.Module()
    m.p1 = nn.Parameter(torch.Tensor([50, 30.0]))
    m.register_buffer("p2", torch.tensor([1, 3]))
    params = {"a": 100, "b": 200}

    save_checkpoint(f, m, params=params)
    return f


def test_load_checkpoints(checkpoints1):
    m = nn.Module()
    m.p1 = nn.Parameter(torch.Tensor([0, 0.0]))
    m.p2 = nn.Parameter(torch.Tensor([0, 0]))
    params = load_checkpoint(checkpoints1, m)
    assert torch.allclose(m.p1, torch.Tensor([10.0, 20]))
    assert params == {"a": 10, "b": 20}


def test_average_checkpoints(checkpoints1, checkpoints2):
    state_dict = average_checkpoints([checkpoints1, checkpoints2])
    assert torch.allclose(state_dict["p1"], torch.Tensor([30, 25.0]))
    assert torch.allclose(state_dict["p2"], torch.tensor([5, 51]))
160
test/test_graph_compiler.py
Normal file
@ -0,0 +1,160 @@
#!/usr/bin/env python3

# Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)

import re

import k2
import pytest
import torch

from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import get_texts


@pytest.fixture
def lexicon():
    """
    We use the following test data:

    lexicon.txt

        foo f o o
        bar b a r
        baz b a z
        <UNK> SPN

    phones.txt

        <eps> 0
        a 1
        b 2
        f 3
        o 4
        r 5
        z 6
        SPN 7

    words.txt:

        <eps> 0
        foo 1
        bar 2
        baz 3
        <UNK> 4
    """
    L = k2.Fsa.from_str(
        """
        0 0 7 4 0
        0 7 -1 -1 0
        0 1 3 1 0
        0 3 2 2 0
        0 5 2 3 0
        1 2 4 0 0
        2 0 4 0 0
        3 4 1 0 0
        4 0 5 0 0
        5 6 1 0 0
        6 0 6 0 0
        7
        """,
        num_aux_labels=1,
    )
    L.labels_sym = k2.SymbolTable.from_str(
        """
        a 1
        b 2
        f 3
        o 4
        r 5
        z 6
        SPN 7
        """
    )
    L.aux_labels_sym = k2.SymbolTable.from_str(
        """
        foo 1
        bar 2
        baz 3
        <UNK> 4
        """
    )
    ans = Lexicon.__new__(Lexicon)
    ans.phones = L.labels_sym
    ans.words = L.aux_labels_sym
    ans.L_inv = k2.arc_sort(L.invert_())
    ans.disambig_pattern = re.compile(r"^#\d+$")

    return ans


@pytest.fixture
def compiler(lexicon):
    return CtcTrainingGraphCompiler(lexicon, device=torch.device("cpu"))


class TestCtcTrainingGraphCompiler(object):
    @staticmethod
    def test_convert_transcript_to_fsa(compiler, lexicon):
        texts = ["bar foo", "baz ok"]
        fsa = compiler.convert_transcript_to_fsa(texts)
        labels0 = fsa[0].labels[:-1].tolist()
        aux_labels0 = fsa[0].aux_labels[:-1]
        aux_labels0 = aux_labels0[aux_labels0 != 0].tolist()

        labels1 = fsa[1].labels[:-1].tolist()
        aux_labels1 = fsa[1].aux_labels[:-1]
        aux_labels1 = aux_labels1[aux_labels1 != 0].tolist()

        labels0 = [lexicon.phones[i] for i in labels0]
        labels1 = [lexicon.phones[i] for i in labels1]

        aux_labels0 = [lexicon.words[i] for i in aux_labels0]
        aux_labels1 = [lexicon.words[i] for i in aux_labels1]

        assert labels0 == ["b", "a", "r", "f", "o", "o"]
        assert aux_labels0 == ["bar", "foo"]

        assert labels1 == ["b", "a", "z", "SPN"]
        assert aux_labels1 == ["baz", "<UNK>"]

    @staticmethod
    def test_compile(compiler, lexicon):
        texts = ["bar foo", "baz ok"]
        decoding_graph = compiler.compile(texts)
        input1 = ["b", "b", "<blk>", "<blk>", "a", "a", "r", "<blk>", "<blk>"]
        input1 += ["f", "f", "<blk>", "<blk>", "o", "o", "<blk>", "o", "o"]

        input2 = ["b", "b", "a", "a", "a", "<blk>", "<blk>", "z", "z"]
        input2 += ["<blk>", "<blk>", "SPN", "SPN", "<blk>", "<blk>"]

        # Register "<blk>" as the symbol with ID 0 so that the inputs above
        # can be mapped to phone IDs.
        lexicon.phones._id2sym[0] = "<blk>"
        lexicon.phones._sym2id["<blk>"] = 0

        input1 = [lexicon.phones[i] for i in input1]
        input2 = [lexicon.phones[i] for i in input2]

        fsa1 = k2.linear_fsa(input1)
        fsa2 = k2.linear_fsa(input2)
        fsas = k2.Fsa.from_fsas([fsa1, fsa2])

        decoding_graph = k2.arc_sort(decoding_graph)
        lattice = k2.intersect(
            decoding_graph, fsas, treat_epsilons_specially=False
        )
        lattice = k2.connect(lattice)

        aux_labels0 = lattice[0].aux_labels[:-1]
        aux_labels0 = aux_labels0[aux_labels0 != 0].tolist()
        aux_labels0 = [lexicon.words[i] for i in aux_labels0]
        assert aux_labels0 == ["bar", "foo"]

        aux_labels1 = lattice[1].aux_labels[:-1]
        aux_labels1 = aux_labels1[aux_labels1 != 0].tolist()
        aux_labels1 = [lexicon.words[i] for i in aux_labels1]
        assert aux_labels1 == ["baz", "<UNK>"]

        texts = get_texts(lattice)
        texts = [[lexicon.words[i] for i in words] for words in texts]
        assert texts == [["bar", "foo"], ["baz", "<UNK>"]]
93
test/test_utils.py
Normal file
@ -0,0 +1,93 @@
#!/usr/bin/env python3

import k2
import pytest
import torch

from icefall.utils import AttributeDict, encode_supervisions, get_texts


@pytest.fixture
def sup():
    sequence_idx = torch.tensor([0, 1, 2])
    start_frame = torch.tensor([1, 3, 9])
    num_frames = torch.tensor([20, 30, 10])
    text = ["one", "two", "three"]
    return {
        "sequence_idx": sequence_idx,
        "start_frame": start_frame,
        "num_frames": num_frames,
        "text": text,
    }


def test_encode_supervisions(sup):
    supervision_segments, texts = encode_supervisions(sup, subsampling_factor=4)
    assert torch.all(
        torch.eq(
            supervision_segments,
            torch.tensor(
                [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]
            ),
        )
    )
    assert texts == ["two", "one", "three"]


def test_get_texts_ragged():
    fsa1 = k2.Fsa.from_str(
        """
        0 1 1 10
        1 2 2 20
        2 3 3 30
        3 4 -1 0
        4
        """
    )
    fsa1.aux_labels = k2.RaggedInt("[ [1 3 0 2] [] [4 0 1] [-1]]")

    fsa2 = k2.Fsa.from_str(
        """
        0 1 1 1
        1 2 2 2
        2 3 -1 0
        3
        """
    )
    fsa2.aux_labels = k2.RaggedInt("[[3 0 5 0 8] [0 9 7 0] [-1]]")
    fsas = k2.Fsa.from_fsas([fsa1, fsa2])
    texts = get_texts(fsas)
    assert texts == [[1, 3, 2, 4, 1], [3, 5, 8, 9, 7]]


def test_get_texts_regular():
    fsa1 = k2.Fsa.from_str(
        """
        0 1 1 3 10
        1 2 2 0 20
        2 3 3 2 30
        3 4 -1 -1 0
        4
        """,
        num_aux_labels=1,
    )

    fsa2 = k2.Fsa.from_str(
        """
        0 1 1 10 1
        1 2 2 5 2
        2 3 -1 -1 0
        3
        """,
        num_aux_labels=1,
    )
    fsas = k2.Fsa.from_fsas([fsa1, fsa2])
    texts = get_texts(fsas)
    assert texts == [[3, 2], [10, 5]]


def test_attribute_dict():
    s = AttributeDict({"a": 10, "b": 20})
    assert s.a == 10
    assert s["b"] == 20
    s.c = 100
    assert s["c"] == 100
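The new unit tests can be run with pytest from the repository root; a minimal sketch of invoking them programmatically (equivalent to running `pytest test` on the command line, assuming pytest is installed):

# Illustrative only; not part of the commit.
import sys

import pytest

sys.exit(pytest.main(["-q", "test"]))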