Refactoring (#4)

* Fix an error in TDNN-LSTM training.

* WIP: Refactoring

* Refactor transformer.py

* Remove unused code.

* Minor fixes.
Fangjun Kuang, 2021-08-04 14:53:02 +08:00, committed by GitHub
parent cf8d76293d
commit 5a0b9bcb23
23 changed files with 968 additions and 648 deletions

.gitignore (vendored)

@@ -4,3 +4,4 @@ path.sh
exp
exp*/
*.pt
download/

egs/librispeech/ASR/conformer_ctc/conformer.py

@@ -84,20 +84,26 @@ class Conformer(Transformer):
# and throws an error without this change.
self.after_norm = identity
def encode(
def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x: Tensor of dimension (batch_size, num_features, input_length).
supervisions : Supervison in lhotse format, i.e., batch['supervisions']
x:
The model input. Its shape is [N, T, C].
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
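As a point of reference, here is a minimal usage sketch of the renamed entry point. It is editorial, not part of the commit: the constructor arguments are assumed to mirror the Transformer call in test_transformer.py further down, and the shapes follow the new [N, T, C] input convention documented above.

import torch
from conformer import Conformer

model = Conformer(num_features=80, num_classes=500)  # assumed signature
x = torch.rand(2, 100, 80)  # [N, T, C]
memory, mask = model.run_encoder(x, supervisions=None)
# memory: [T', N, d_model] with T' = ((100 - 1) // 2 - 1) // 2 = 24
# mask: None here, since no supervisions were passed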

egs/librispeech/ASR/conformer_ctc/decode.py

@@ -15,6 +15,7 @@ import torch
import torch.nn as nn
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
@@ -62,7 +63,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang/bpe"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"nhead": 8,
@@ -85,7 +86,7 @@ def get_params() -> AttributeDict:
# - whole-lattice-rescoring
# - attention-decoder
# "method": "whole-lattice-rescoring",
"method": "1best",
"method": "attention-decoder",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
"num_paths": 100,
@@ -100,6 +101,8 @@ def decode_one_batch(
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
@@ -133,6 +136,10 @@
for the format of the `batch`.
lexicon:
It contains the word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
@@ -147,15 +154,10 @@
feature = feature.to(device)
# at entry, feature is [N, T, C]
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is [N, C, T]
nnet_output = nnet_output.permute(0, 2, 1)
# now nnet_output is [N, T, C]
# nnet_output is [N, T, C]
supervision_segments = torch.stack(
(
@@ -227,6 +229,8 @@
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
@@ -245,6 +249,8 @@ def decode_dataset(
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
"""Decode dataset.
@@ -260,6 +266,10 @@
The decoding graph.
lexicon:
It contains the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
@@ -287,6 +297,8 @@
batch=batch,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
for lm_scale, hyps in hyps_dict.items():
@@ -314,20 +326,31 @@ def save_results(
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -367,15 +390,22 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt"))
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
# HLG = k2.ctc_topo(4999).to(device)
if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
@@ -461,6 +491,8 @@
HLG=HLG,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
save_results(
@@ -470,5 +502,8 @@
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

egs/librispeech/ASR/conformer_ctc/subsampling.py

@@ -0,0 +1,144 @@
import torch
import torch.nn as nn
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(self, idim: int, odim: int) -> None:
"""
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
assert idim >= 7
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
nn.Conv2d(
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
),
nn.ReLU(),
)
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
# On entry, x is [N, T, idim]
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
x = self.conv(x)
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape [N, ((T-1)//2 - 1)//2, odim]
return x
class VggSubsampling(nn.Module):
"""Trying to follow the setup described in the following paper:
https://arxiv.org/pdf/1910.09799.pdf
This paper is not 100% explicit so I am guessing to some extent,
and trying to compare with other VGG implementations.
Convert an input of shape [N, T, idim] to an output
with shape [N, T', odim], where
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
"""
def __init__(self, idim: int, odim: int) -> None:
"""Construct a VggSubsampling object.
This uses 2 VGG blocks with 2 Conv2d layers each,
subsampling its input by a factor of 4 in the time dimension.
Args:
idim:
Input dim. The input shape is [N, T, idim].
Caution: It requires: T >=7, idim >=7
odim:
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
"""
super().__init__()
cur_channels = 1
layers = []
block_dims = [32, 64]
# The decision to use padding=1 for the 1st convolution, then padding=0
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
# a back-compatibility concern so that the number of frames at the
# output would be equal to:
# (((T-1)//2)-1)//2.
# We can consider changing this by using padding=1 on the
# 2nd convolution, so the num-frames at the output would be T//4.
for block_dim in block_dims:
layers.append(
torch.nn.Conv2d(
in_channels=cur_channels,
out_channels=block_dim,
kernel_size=3,
padding=1,
stride=1,
)
)
layers.append(torch.nn.ReLU())
layers.append(
torch.nn.Conv2d(
in_channels=block_dim,
out_channels=block_dim,
kernel_size=3,
padding=0,
stride=1,
)
)
layers.append(
torch.nn.MaxPool2d(
kernel_size=2, stride=2, padding=0, ceil_mode=True
)
)
cur_channels = block_dim
self.layers = nn.Sequential(*layers)
self.out = nn.Linear(
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Subsample x.
Args:
x:
Its shape is [N, T, idim].
Returns:
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
"""
x = x.unsqueeze(1)
x = self.layers(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
return x
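The test file that follows exercises both modules over a sweep of shapes; as a single worked instance of the length formula, here is an editorial sketch assuming this subsampling.py is importable:

import torch
from subsampling import Conv2dSubsampling, VggSubsampling

# For T = 100: ((100 - 1) // 2 - 1) // 2 = (49 - 1) // 2 = 24, i.e. ~T // 4
x = torch.rand(3, 100, 80)  # [N, T, idim]
for Module in (Conv2dSubsampling, VggSubsampling):
    y = Module(idim=80, odim=256)(x)
    assert y.shape == (3, 24, 256)  # [N, ((T - 1) // 2 - 1) // 2, odim]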

egs/librispeech/ASR/conformer_ctc/test_subsampling.py

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch
def test_conv2d_subsampling():
N = 3
odim = 2
for T in range(7, 19):
for idim in range(7, 20):
model = Conv2dSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim
def test_vgg_subsampling():
N = 3
odim = 2
for T in range(7, 19):
for idim in range(7, 20):
model = VggSubsampling(idim=idim, odim=odim)
x = torch.empty(N, T, idim)
y = model(x)
assert y.shape[0] == N
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
assert y.shape[2] == odim

egs/librispeech/ASR/conformer_ctc/test_transformer.py

@@ -0,0 +1,89 @@
#!/usr/bin/env python3
import torch
from transformer import (
Transformer,
encoder_padding_mask,
generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
)
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask():
supervisions = {
"sequence_idx": torch.tensor([0, 1, 2]),
"start_frame": torch.tensor([0, 0, 0]),
"num_frames": torch.tensor([18, 7, 13]),
}
max_len = ((18 - 1) // 2 - 1) // 2
mask = encoder_padding_mask(max_len, supervisions)
expected_mask = torch.tensor(
[
[False, False, False], # ((18 - 1)//2 - 1)//2 = 3,
[False, True, True], # ((7 - 1)//2 - 1)//2 = 1,
[False, False, True], # ((13 - 1)//2 - 1)//2 = 2,
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_transformer():
num_features = 40
num_classes = 87
model = Transformer(num_features=num_features, num_classes=num_classes)
N = 31
for T in range(7, 30):
x = torch.rand(N, T, num_features)
y, _, _ = model(x)
assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
def test_generate_square_subsequent_mask():
s = 5
mask = generate_square_subsequent_mask(s)
inf = float("inf")
expected_mask = torch.tensor(
[
[0.0, -inf, -inf, -inf, -inf],
[0.0, 0.0, -inf, -inf, -inf],
[0.0, 0.0, 0.0, -inf, -inf],
[0.0, 0.0, 0.0, 0.0, -inf],
[0.0, 0.0, 0.0, 0.0, 0.0],
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_decoder_padding_mask():
x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
y = pad_sequence(x, batch_first=True, padding_value=-1)
mask = decoder_padding_mask(y, ignore_id=-1)
expected_mask = torch.tensor(
[
[False, False, True],
[False, True, True],
[False, False, False],
]
)
assert torch.all(torch.eq(mask, expected_mask))
def test_add_sos():
x = [[1, 2], [3], [2, 5, 8]]
y = add_sos(x, sos_id=0)
expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
assert y == expected_y
def test_add_eos():
x = [[1, 2], [3], [2, 5, 8]]
y = add_eos(x, eos_id=0)
expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
assert y == expected_y
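Judging from the expected values in these tests, add_sos and add_eos are plain list operations. The following is a sketch consistent with the tests; the actual definitions live in transformer.py, whose diff is suppressed below.

from typing import List

def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
    # Prepend sos_id to each utterance's token sequence.
    return [[sos_id] + utt for utt in token_ids]

def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
    # Append eos_id to each utterance's token sequence.
    return [utt + [eos_id] for utt in token_ids]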

egs/librispeech/ASR/conformer_ctc/train.py

@@ -125,7 +125,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang/bpe"),
"lang_dir": Path("data/lang_bpe"),
"feature_dim": 80,
"weight_decay": 0.0,
"subsampling_factor": 4,
@@ -275,15 +275,13 @@ def compute_loss(
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is [N, T, C]
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
# nnet_output is [N, C, T]
nnet_output = nnet_output.permute(0, 2, 1) # [N, C, T] -> [N, T, C]
# nnet_output is [N, T, C]
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
@@ -536,6 +534,22 @@ def train_one_epoch(
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/valid_ctc_loss",
params.valid_ctc_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_att_loss",
params.valid_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/valid_loss",
params.valid_loss,
params.batch_idx_train,
)
params.train_loss = tot_loss / tot_frames
@@ -675,5 +689,8 @@ def main():
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

egs/librispeech/ASR/conformer_ctc/transformer.py (file diff suppressed because it is too large)

egs/librispeech/ASR/local/compile_hlg.py

@@ -26,7 +26,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang or data/lang/bpe.
The language directory, e.g., data/lang_phone or data/lang_bpe.
Return:
An FSA representing HLG.
@@ -45,7 +45,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "G_3_gram.pt")
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
@@ -103,30 +103,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
return HLG
def phone_based_HLG():
if Path("data/lm/HLG.pt").is_file():
return
logging.info("Compiling phone based HLG")
HLG = compile_HLG("data/lang")
logging.info("Saving HLG.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG.pt")
def bpe_based_HLG():
if Path("data/lm/HLG_bpe.pt").is_file():
return
logging.info("Compiling BPE based HLG")
HLG = compile_HLG("data/lang/bpe")
logging.info("Saving HLG_bpe.pt to data/lm")
torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
def main():
phone_based_HLG()
bpe_based_HLG()
for d in ["data/lang_phone", "data/lang_bpe"]:
d = Path(d)
logging.info(f"Processing {d}")
if (d / "HLG.pt").is_file():
logging.info(f"{d}/HLG.pt already exists - skipping")
continue
HLG = compile_HLG(d)
logging.info(f"Saving HLG.pt to {d}")
torch.save(HLG.as_dict(), f"{d}/HLG.pt")
if __name__ == "__main__":
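To consume the result, load the saved dict and rebuild the FSA, mirroring how decode.py does it earlier in this diff. An editorial sketch:

import k2
import torch

# Load the HLG compiled above for use in decoding.
HLG = k2.Fsa.from_dict(torch.load("data/lang_bpe/HLG.pt"))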

egs/librispeech/ASR/local/compute_fbank_librispeech.py

@@ -1,19 +1,28 @@
#!/usr/bin/env python3
"""
This file computes fbank features of the librispeech dataset.
Its looks for manifests in the directory data/manifests
and generated fbank features are saved in data/fbank.
This file computes fbank features of the LibriSpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
# slow things down. Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_librispeech():
src_dir = Path("data/manifests")
@@ -40,12 +49,11 @@ def compute_fbank_librispeech():
with get_executor() as ex: # Initialize the executor only once.
for partition, m in manifests.items():
if (output_dir / f"cuts_{partition}.json.gz").is_file():
print(f"{partition} already exists - skipping.")
logging.info(f"{partition} already exists - skipping.")
continue
print("Processing", partition)
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
recordings=m["recordings"], supervisions=m["supervisions"],
)
if "train" in partition:
cut_set = (
@@ -65,4 +73,10 @@ def compute_fbank_librispeech():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_librispeech()

egs/librispeech/ASR/local/compute_fbank_musan.py

@@ -2,18 +2,27 @@
"""
This file computes fbank features of the musan dataset.
Its looks for manifests in the directory data/manifests
and generated fbank features are saved in data/fbank.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
# slow things down. Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_musan():
src_dir = Path("data/manifests")
@@ -34,10 +43,10 @@ def compute_fbank_musan():
musan_cuts_path = output_dir / "cuts_musan.json.gz"
if musan_cuts_path.is_file():
print(f"{musan_cuts_path} already exists - skipping")
logging.info(f"{musan_cuts_path} already exists - skipping")
return
print("Extracting features for Musan")
logging.info("Extracting features for Musan")
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@@ -63,4 +72,9 @@ def compute_fbank_musan():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()

egs/librispeech/ASR/local/download_lm.py

@@ -2,10 +2,25 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This file downloads librispeech LM files to data/lm
This file downloads the following LibriSpeech LM files:
- 3-gram.pruned.1e-7.arpa.gz
- 4-gram.arpa.gz
- librispeech-vocab.txt
- librispeech-lexicon.txt
from http://www.openslr.org/resources/11
and saves them in the user-provided directory.
Files are not re-downloaded if they already exist.
Usage:
./local/download_lm.py --out-dir ./download/lm
"""
import argparse
import gzip
import logging
import os
import shutil
from pathlib import Path
@@ -14,9 +29,17 @@ from lhotse.utils import urlretrieve_progress
from tqdm.auto import tqdm
def download_lm():
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--out-dir", type=str, help="Output directory.")
args = parser.parse_args()
return args
def main(out_dir: str):
url = "http://www.openslr.org/resources/11"
target_dir = Path("data/lm")
out_dir = Path(out_dir)
files_to_download = (
"3-gram.pruned.1e-7.arpa.gz",
@@ -26,7 +49,7 @@ def download_lm():
)
for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
filename = target_dir / f
filename = out_dir / f
if filename.is_file() is False:
urlretrieve_progress(
f"{url}/{f}",
@@ -34,17 +57,26 @@ def download_lm():
desc=f"Downloading {filename}",
)
else:
print(f"{filename} already exists - skipping")
logging.info(f"{filename} already exists - skipping")
if ".gz" in str(filename):
unzip_file = Path(os.path.splitext(filename)[0])
if unzip_file.is_file() is False:
unzipped = Path(os.path.splitext(filename)[0])
if unzipped.is_file() is False:
with gzip.open(filename, "rb") as f_in:
with open(unzip_file, "wb") as f_out:
with open(unzipped, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
else:
print(f"{unzip_file} already exist - skipping")
logging.info(f"{unzipped} already exist - skipping")
if __name__ == "__main__":
download_lm()
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
logging.info(f"out_dir: {args.out_dir}")
main(out_dir=args.out_dir)

egs/librispeech/ASR/local/prepare_lang.py

@@ -3,7 +3,7 @@
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
"""
This script takes as input a lexicon file "data/lang/lexicon.txt"
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:
1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
@@ -20,8 +20,6 @@ consisting of words and tokens (i.e., phones) and does the following:
5. Generate L_disambig.pt, in k2 format.
"""
import math
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
@@ -284,7 +282,9 @@ def lexicon_to_fst(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
@@ -301,7 +301,7 @@ def lexicon_to_fst(
def main():
out_dir = Path("data/lang")
out_dir = Path("data/lang_phone")
lexicon_filename = out_dir / "lexicon.txt"
sil_token = "SIL"
sil_prob = 0.5

egs/librispeech/ASR/local/prepare_lang_bpe.py

@@ -5,10 +5,10 @@
"""
This script takes as inputs the following two files:
- data/lang/bpe/bpe.model,
- data/lang/bpe/words.txt
- data/lang_bpe/bpe.model,
- data/lang_bpe/words.txt
and generates the following files in the directory data/lang/bpe:
and generates the following files in the directory data/lang_bpe:
- lexicon.txt
- lexicon_disambig.txt
@@ -88,7 +88,9 @@ def lexicon_to_fst_no_sil(
disambig_token = token2id["#0"]
disambig_word = word2id["#0"]
arcs = add_self_loops(
arcs, disambig_token=disambig_token, disambig_word=disambig_word,
arcs,
disambig_token=disambig_token,
disambig_word=disambig_word,
)
final_state = next_state
@@ -140,7 +142,7 @@ def generate_lexicon(
def main():
lang_dir = Path("data/lang/bpe")
lang_dir = Path("data/lang_bpe")
model_file = lang_dir / "bpe.model"
word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
@@ -173,7 +175,9 @@ def main():
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
L = lexicon_to_fst_no_sil(
lexicon, token2id=token_sym_table, word2id=word_sym_table,
lexicon,
token2id=token_sym_table,
word2id=word_sym_table,
)
L_disambig = lexicon_to_fst_no_sil(

egs/librispeech/ASR/local/train_bpe_model.py

@@ -14,18 +14,17 @@ and generates "data/lang/bpe/bpe.model".
#
# Please install a version >=0.1.96
import shutil
from pathlib import Path
import sentencepiece as spm
import shutil
def main():
model_type = "unigram"
vocab_size = 5000
model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}"
train_text = "data/lang/bpe/train.txt"
model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
train_text = "data/lang_bpe/train.txt"
character_coverage = 1.0
input_sentence_size = 100000000
@@ -53,7 +52,7 @@ def main():
sp = spm.SentencePieceProcessor(model_file=str(model_file))
vocab_size = sp.vocab_size()
shutil.copyfile(model_file, "data/lang/bpe/bpe.model")
shutil.copyfile(model_file, "data/lang_bpe/bpe.model")
if __name__ == "__main__":

egs/librispeech/ASR/prepare.sh

@@ -6,8 +6,38 @@ nj=15
stage=-1
stop_stage=100
. local/parse_options.sh || exit 1
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/LibriSpeech
# You can find BOOKS.TXT, test-clean, train-clean-360, etc., inside it.
# You can download them from https://www.openslr.org/12
#
# - $dl_dir/lm
# This directory contains the following files downloaded from
# http://www.openslr.org/resources/11
#
# - 3-gram.pruned.1e-7.arpa.gz
# - 3-gram.pruned.1e-7.arpa
# - 4-gram.arpa.gz
# - 4-gram.arpa
# - librispeech-vocab.txt
# - librispeech-lexicon.txt
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All generated files by this script are saved in "data"
mkdir -p data
log() {
@@ -16,10 +46,11 @@ log() {
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "stage -1: Download LM"
mkdir -p data/lm
./local/download_lm.py
./local/download_lm.py --out-dir=$dl_dir/lm
fi
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@@ -28,38 +59,28 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# If you have pre-downloaded it to /path/to/LibriSpeech,
# you can create a symlink
#
# ln -sfv /path/to/LibriSpeech data/
# ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
#
# The script checks that if
#
# data/LibriSpeech/test-clean/.completed exists,
#
# it will not re-download it.
#
# The same goes for dev-clean, dev-other, test-other, train-clean-100
# train-clean-360, and train-other-500
mkdir -p data/LibriSpeech
lhotse download librispeech --full data
if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
lhotse download librispeech --full $dl_dir
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan data/
# ln -sfv /path/to/musan $dl_dir/
#
# and create a file data/.musan_completed
# to avoid downloading it again
if [ ! -f data/.musan_completed ]; then
lhotse download musan data
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare librispeech manifest"
# We assume that you have downloaded the librispeech corpus
# to data/LibriSpeech
log "Stage 1: Prepare LibriSpeech manifest"
# We assume that you have downloaded the LibriSpeech corpus
# to $dl_dir/LibriSpeech
mkdir -p data/manifests
lhotse prepare librispeech -j $nj data/LibriSpeech data/manifests
lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -67,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# We assume that you have downloaded the musan corpus
# to data/musan
mkdir -p data/manifests
lhotse prepare musan data/musan data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@@ -84,24 +105,25 @@ fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
# TODO: add BPE based lang
mkdir -p data/lang
mkdir -p data/lang_phone
(echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
cat - data/lm/librispeech-lexicon.txt |
sort | uniq > data/lang/lexicon.txt
cat - $dl_dir/lm/librispeech-lexicon.txt |
sort | uniq > data/lang_phone/lexicon.txt
if [ ! -f data/lang/L_disambig.pt ]; then
if [ ! -f data/lang_phone/L_disambig.pt ]; then
./local/prepare_lang.py
fi
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "State 6: Prepare BPE based lang"
mkdir -p data/lang/bpe
cp data/lang/words.txt data/lang/bpe/
mkdir -p data/lang_bpe
# We reuse words.txt from phone based lexicon
# so that the two can share G.pt later.
cp data/lang_phone/words.txt data/lang_bpe/
if [ ! -f data/lang/bpe/train.txt ]; then
if [ ! -f data/lang_bpe/train.txt ]; then
log "Generate data for BPE training"
files=$(
find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"
@@ -110,12 +132,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
)
for f in ${files[@]}; do
cat $f | cut -d " " -f 2-
done > data/lang/bpe/train.txt
done > data/lang_bpe/train.txt
fi
python3 ./local/train_bpe_model.py
if [ ! -f data/lang/bpe/L_disambig.pt ]; then
if [ ! -f data/lang_bpe/L_disambig.pt ]; then
./local/prepare_lang_bpe.py
fi
fi
@@ -125,22 +147,23 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
# We assume you have installed kaldilm; if not, please install
# it using: pip install kaldilm
mkdir -p data/lm
if [ ! -f data/lm/G_3_gram.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="data/lang/words.txt" \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=3 \
data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
$dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
fi
if [ ! -f data/lm/G_4_gram.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="data/lang/words.txt" \
--read-symbol-table="data/lang_phone/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
data/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
$dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
fi
fi

egs/librispeech/ASR/shared (symbolic link)

@@ -0,0 +1 @@
../../../icefall/shared/

egs/librispeech/ASR/tdnn_lstm_ctc/decode.py

@@ -58,7 +58,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp/"),
"lang_dir": Path("data/lang"),
"lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"),
"feature_dim": 80,
"subsampling_factor": 3,
@@ -328,7 +328,7 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
@@ -340,7 +340,7 @@ def main():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.words["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so

egs/librispeech/ASR/tdnn_lstm_ctc/train.py

@@ -127,7 +127,7 @@ def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp"),
"lang_dir": Path("data/lang"),
"lang_dir": Path("data/lang_phone"),
"lr": 1e-3,
"feature_dim": 80,
"weight_decay": 5e-4,
@@ -501,8 +501,9 @@ def run(rank, world_size, args):
)
scheduler = StepLR(optimizer, step_size=8, gamma=0.1)
optimizer.load_state_dict(checkpoints["optimizer"])
scheduler.load_state_dict(checkpoints["scheduler"])
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
scheduler.load_state_dict(checkpoints["scheduler"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()

icefall/decode.py

@@ -555,24 +555,31 @@ def rescore_with_attention_decoder(
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
sos_id: int,
eos_id: int,
) -> Dict[str, k2.Fsa]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
score is used as the decoding output.
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `[T, N, C]`.
memory_key_padding_mask:
The padding mask for memory with shape [N, T].
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `[T, N, C]`.
memory_key_padding_mask:
The padding mask for memory with shape [N, T].
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
@@ -661,7 +668,11 @@ def rescore_with_attention_decoder(
# TODO: pass the sos_token_id and eos_token_id via function arguments
nll = model.decoder_nll(
expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == num_word_seqs
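The remainder of the function falls outside this hunk; presumably the per-token NLL is then reduced to one attention-decoder score per path. A hedged sketch of that reduction:

import torch

def attention_score(nll: torch.Tensor) -> torch.Tensor:
    # nll: [num_paths, num_tokens], negative log-likelihood per token.
    # Summing over tokens and negating yields one log-likelihood score
    # per path (higher is better). Assumed, not shown in this hunk.
    return -nll.sum(dim=1)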

icefall/lexicon.py

@@ -1,7 +1,8 @@
import logging
import re
import sys
from pathlib import Path
from typing import List, Tuple, Union
from typing import List, Tuple
import k2
import torch
@@ -31,13 +32,19 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
continue
if len(a) < 2:
print(f"Found bad line {line} in lexicon file {filename}")
print("Every line is expected to contain at least 2 fields")
logging.info(
f"Found bad line {line} in lexicon file {filename}"
)
logging.info(
"Every line is expected to contain at least 2 fields"
)
sys.exit(1)
word = a[0]
if word == "<eps>":
print(f"Found bad line {line} in lexicon file {filename}")
print("<eps> should not be a valid word")
logging.info(
f"Found bad line {line} in lexicon file {filename}"
)
logging.info("<eps> should not be a valid word")
sys.exit(1)
tokens = a[1:]
@@ -61,13 +68,12 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
class Lexicon(object):
"""Phone based lexicon.
TODO: Add BpeLexicon for BPE models.
"""
"""Phone based lexicon."""
def __init__(
self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
self,
lang_dir: Path,
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Args:
@@ -121,7 +127,9 @@ class Lexicon(object):
class BpeLexicon(Lexicon):
def __init__(
self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
self,
lang_dir: Path,
disambig_pattern: str = re.compile(r"^#\d+$"),
):
"""
Refer to the help information in Lexicon.__init__.

icefall/utils.py

@@ -225,7 +225,10 @@ def store_transcripts(
def write_error_stats(
f: TextIO, test_set_name: str, results: List[Tuple[str, str]]
f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
) -> float:
"""Write statistics based on predicted results and reference transcripts.
@@ -255,6 +258,9 @@
results:
An iterable of tuples. The first element is the reference transcript
while the second element is the predicted result.
enable_log:
If True, also print detailed WER to the console.
Otherwise, it is written only to the given file.
Returns:
Return the word error rate (WER) as a float.
"""
@@ -290,11 +296,12 @@
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(