Merge branch 'master' into nbest

commit 286dce7b0f
pkufool, 2021-08-04 15:52:14 +08:00

23 changed files with 968 additions and 648 deletions

.gitignore

@@ -4,3 +4,4 @@ path.sh
 exp
 exp*/
 *.pt
+download/

egs/librispeech/ASR/conformer_ctc/conformer.py

@@ -84,20 +84,26 @@ class Conformer(Transformer):
             # and throws an error without this change.
             self.after_norm = identity

-    def encode(
+    def run_encoder(
         self, x: Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
         """
         Args:
-          x: Tensor of dimension (batch_size, num_features, input_length).
-          supervisions : Supervison in lhotse format, i.e., batch['supervisions']
+          x:
+            The model input. Its shape is [N, T, C].
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling.
+            It is read directly from the batch, without any sorting. It is used
+            to compute encoder padding mask, which is used as memory key padding
+            mask for the decoder.

         Returns:
             Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
             Tensor: Mask tensor of dimension (batch_size, input_length)
         """
-        x = x.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)
         x = self.encoder_embed(x)
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
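With run_encoder the model consumes features in [N, T, C] layout directly, which is why the (B, F, T) permute above is deleted and why the matching permutes disappear from train.py and decode.py later in this commit. A minimal sketch of the new calling convention (shapes are illustrative; model and supervisions are assumed to be set up as in decode.py):

    import torch

    # A batch of 16 utterances, 500 frames, 80 mel bins: [N, T, C],
    # the layout in which the lhotse dataset returns batch["inputs"].
    feature = torch.randn(16, 500, 80)

    # No permute to [N, C, T] before the call and none afterwards:
    # nnet_output also comes back as [N, T, C].
    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)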

egs/librispeech/ASR/conformer_ctc/decode.py

@@ -15,6 +15,7 @@ import torch
 import torch.nn as nn
 from conformer import Conformer

+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
@@ -62,7 +63,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "nhead": 8,
@@ -85,7 +86,7 @@ def get_params() -> AttributeDict:
             #  - whole-lattice-rescoring
             #  - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "1best",
+            "method": "attention-decoder",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
             "num_paths": 100,
@@ -100,6 +101,8 @@ def decode_one_batch(
     HLG: k2.Fsa,
     batch: dict,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -133,6 +136,10 @@ def decode_one_batch(
         for the format of the `batch`.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -147,15 +154,10 @@ def decode_one_batch(
     feature = feature.to(device)
     # at entry, feature is [N, T, C]
-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

     supervisions = batch["supervisions"]

     nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is [N, C, T]
-    nnet_output = nnet_output.permute(0, 2, 1)
-    # now nnet_output is [N, T, C]
+    # nnet_output is [N, T, C]

     supervision_segments = torch.stack(
         (
@@ -227,6 +229,8 @@ def decode_one_batch(
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -245,6 +249,8 @@ def decode_dataset(
     model: nn.Module,
     HLG: k2.Fsa,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
     """Decode dataset.
@@ -260,6 +266,10 @@ def decode_dataset(
         The decoding graph.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -287,6 +297,8 @@ def decode_dataset(
                 batch=batch,
                 lexicon=lexicon,
                 G=G,
+                sos_id=sos_id,
+                eos_id=eos_id,
             )

             for lm_scale, hyps in hyps_dict.items():
@@ -314,20 +326,31 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
 ):
+    if params.method == "attention-decoder":
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
         store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+            )
             test_set_wers[key] = wer

-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        if enable_log:
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -367,15 +390,22 @@ def main():

     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt"))
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False

     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()

-    # HLG = k2.ctc_topo(4999).to(device)
     if params.method in (
         "nbest-rescoring",
         "whole-lattice-rescoring",
@@ -461,6 +491,8 @@ def main():
             HLG=HLG,
             lexicon=lexicon,
             G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )

         save_results(
@@ -470,5 +502,8 @@ def main():
     logging.info("Done!")

+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)

 if __name__ == "__main__":
     main()
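The sos_id/eos_id threading above replaces a hard-coded "1, 1" inside rescore_with_attention_decoder (see the icefall/decode.py hunk near the end of this commit). In this recipe a single "<sos/eos>" symbol serves both roles, so the two IDs coincide; a small sketch under that assumption:

    graph_compiler = BpeCtcTrainingGraphCompiler(
        "data/lang_bpe",
        device="cpu",
        sos_token="<sos/eos>",
        eos_token="<sos/eos>",
    )
    # Both markers resolve to the same BPE piece, hence the same ID.
    assert graph_compiler.sos_id == graph_compiler.eos_id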

egs/librispeech/ASR/conformer_ctc/subsampling.py

@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+
+
+class Conv2dSubsampling(nn.Module):
+    """Convolutional 2D subsampling (to 1/4 length).
+
+    Convert an input of shape [N, T, idim] to an output
+    with shape [N, T', odim], where
+    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
+
+    It is based on
+    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
+    """
+
+    def __init__(self, idim: int, odim: int) -> None:
+        """
+        Args:
+          idim:
+            Input dim. The input shape is [N, T, idim].
+            Caution: It requires: T >= 7, idim >= 7
+          odim:
+            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+        """
+        assert idim >= 7
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(
+                in_channels=1, out_channels=odim, kernel_size=3, stride=2
+            ),
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
+            ),
+            nn.ReLU(),
+        )
+        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is [N, T, idim].
+
+        Returns:
+          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+        """
+        # On entry, x is [N, T, idim]
+        x = x.unsqueeze(1)  # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
+        x = self.conv(x)
+        # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        # Now x is of shape [N, ((T-1)//2 - 1)//2, odim]
+        return x
+
+
+class VggSubsampling(nn.Module):
+    """Trying to follow the setup described in the following paper:
+    https://arxiv.org/pdf/1910.09799.pdf
+
+    This paper is not 100% explicit so I am guessing to some extent,
+    and trying to compare with other VGG implementations.
+
+    Convert an input of shape [N, T, idim] to an output
+    with shape [N, T', odim], where
+    T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
+    """
+
+    def __init__(self, idim: int, odim: int) -> None:
+        """Construct a VggSubsampling object.
+
+        This uses 2 VGG blocks with 2 Conv2d layers each,
+        subsampling its input by a factor of 4 in the time dimension.
+
+        Args:
+          idim:
+            Input dim. The input shape is [N, T, idim].
+            Caution: It requires: T >= 7, idim >= 7
+          odim:
+            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+        """
+        super().__init__()
+
+        cur_channels = 1
+        layers = []
+        block_dims = [32, 64]
+
+        # The decision to use padding=1 for the 1st convolution, then padding=0
+        # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
+        # a back-compatibility concern so that the number of frames at the
+        # output would be equal to:
+        #   ((T-1)//2 - 1)//2.
+        # We can consider changing this by using padding=1 on the
+        # 2nd convolution, so the num-frames at the output would be T//4.
+        for block_dim in block_dims:
+            layers.append(
+                torch.nn.Conv2d(
+                    in_channels=cur_channels,
+                    out_channels=block_dim,
+                    kernel_size=3,
+                    padding=1,
+                    stride=1,
+                )
+            )
+            layers.append(torch.nn.ReLU())
+            layers.append(
+                torch.nn.Conv2d(
+                    in_channels=block_dim,
+                    out_channels=block_dim,
+                    kernel_size=3,
+                    padding=0,
+                    stride=1,
+                )
+            )
+            layers.append(
+                torch.nn.MaxPool2d(
+                    kernel_size=2, stride=2, padding=0, ceil_mode=True
+                )
+            )
+            cur_channels = block_dim
+
+        self.layers = nn.Sequential(*layers)
+
+        self.out = nn.Linear(
+            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Subsample x.
+
+        Args:
+          x:
+            Its shape is [N, T, idim].
+
+        Returns:
+          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+        """
+        x = x.unsqueeze(1)
+        x = self.layers(x)
+        b, c, t, f = x.size()
+        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
+        return x
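As a quick check of the shape formula in the docstrings: with T = 100 input frames, two stride-2 convolutions give ((100 - 1)//2 - 1)//2 = 24 output frames, roughly T//4. A hypothetical snippet (the idim/odim values are examples, not taken from the recipe):

    import torch
    from subsampling import Conv2dSubsampling

    model = Conv2dSubsampling(idim=80, odim=256)
    x = torch.randn(8, 100, 80)  # [N, T, idim]
    y = model(x)
    assert y.shape == (8, ((100 - 1) // 2 - 1) // 2, 256)  # (8, 24, 256)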

egs/librispeech/ASR/conformer_ctc/test_subsampling.py

@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+
+from subsampling import Conv2dSubsampling
+from subsampling import VggSubsampling
+import torch
+
+
+def test_conv2d_subsampling():
+    N = 3
+    odim = 2
+
+    for T in range(7, 19):
+        for idim in range(7, 20):
+            model = Conv2dSubsampling(idim=idim, odim=odim)
+            x = torch.empty(N, T, idim)
+            y = model(x)
+            assert y.shape[0] == N
+            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
+            assert y.shape[2] == odim
+
+
+def test_vgg_subsampling():
+    N = 3
+    odim = 2
+
+    for T in range(7, 19):
+        for idim in range(7, 20):
+            model = VggSubsampling(idim=idim, odim=odim)
+            x = torch.empty(N, T, idim)
+            y = model(x)
+            assert y.shape[0] == N
+            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
+            assert y.shape[2] == odim

egs/librispeech/ASR/conformer_ctc/test_transformer.py

@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+
+import torch
+from transformer import (
+    Transformer,
+    encoder_padding_mask,
+    generate_square_subsequent_mask,
+    decoder_padding_mask,
+    add_sos,
+    add_eos,
+)
+from torch.nn.utils.rnn import pad_sequence
+
+
+def test_encoder_padding_mask():
+    supervisions = {
+        "sequence_idx": torch.tensor([0, 1, 2]),
+        "start_frame": torch.tensor([0, 0, 0]),
+        "num_frames": torch.tensor([18, 7, 13]),
+    }
+
+    max_len = ((18 - 1) // 2 - 1) // 2
+    mask = encoder_padding_mask(max_len, supervisions)
+    expected_mask = torch.tensor(
+        [
+            [False, False, False],  # ((18 - 1)//2 - 1)//2 = 3,
+            [False, True, True],  # ((7 - 1)//2 - 1)//2 = 1,
+            [False, False, True],  # ((13 - 1)//2 - 1)//2 = 2,
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_transformer():
+    num_features = 40
+    num_classes = 87
+    model = Transformer(num_features=num_features, num_classes=num_classes)
+
+    N = 31
+
+    for T in range(7, 30):
+        x = torch.rand(N, T, num_features)
+        y, _, _ = model(x)
+        assert y.shape == (N, ((T - 1) // 2 - 1) // 2, num_classes)
+
+
+def test_generate_square_subsequent_mask():
+    s = 5
+    mask = generate_square_subsequent_mask(s)
+    inf = float("inf")
+    expected_mask = torch.tensor(
+        [
+            [0.0, -inf, -inf, -inf, -inf],
+            [0.0, 0.0, -inf, -inf, -inf],
+            [0.0, 0.0, 0.0, -inf, -inf],
+            [0.0, 0.0, 0.0, 0.0, -inf],
+            [0.0, 0.0, 0.0, 0.0, 0.0],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_decoder_padding_mask():
+    x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
+    y = pad_sequence(x, batch_first=True, padding_value=-1)
+    mask = decoder_padding_mask(y, ignore_id=-1)
+    expected_mask = torch.tensor(
+        [
+            [False, False, True],
+            [False, True, True],
+            [False, False, False],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected_mask))
+
+
+def test_add_sos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_sos(x, sos_id=0)
+    expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
+    assert y == expected_y
+
+
+def test_add_eos():
+    x = [[1, 2], [3], [2, 5, 8]]
+    y = add_eos(x, eos_id=0)
+    expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
+    assert y == expected_y

egs/librispeech/ASR/conformer_ctc/train.py

@@ -125,7 +125,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "weight_decay": 0.0,
             "subsampling_factor": 4,
@@ -275,15 +275,13 @@ def compute_loss(
     device = graph_compiler.device
     feature = batch["inputs"]
     # at entry, feature is [N, T, C]
-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
     assert feature.ndim == 3
     feature = feature.to(device)

     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, C, T]
-        nnet_output = nnet_output.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+        # nnet_output is [N, T, C]

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
@@ -536,6 +534,22 @@ def train_one_epoch(
                 f" best valid loss: {params.best_valid_loss:.4f} "
                 f"best valid epoch: {params.best_valid_epoch}"
             )
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/valid_ctc_loss",
+                    params.valid_ctc_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/valid_att_loss",
+                    params.valid_att_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/valid_loss",
+                    params.valid_loss,
+                    params.batch_idx_train,
+                )

     params.train_loss = tot_loss / tot_frames

@@ -675,5 +689,8 @@ def main():
     run(rank=0, world_size=1, args=args)

+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)

 if __name__ == "__main__":
     main()
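The new add_scalar calls assume a tb_writer (a torch.utils.tensorboard SummaryWriter) was created earlier in the run; a minimal sketch of what they rely on (the log directory here is an assumption, not a value from this diff):

    from torch.utils.tensorboard import SummaryWriter

    tb_writer = SummaryWriter(log_dir="conformer_ctc/exp/tensorboard")
    # Inside train_one_epoch, after computing the validation losses:
    tb_writer.add_scalar("train/valid_loss", 0.35, 1000)  # tag, value, global step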

File diff suppressed because it is too large

egs/librispeech/ASR/local/compile_hlg.py

@@ -26,7 +26,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang or data/lang/bpe.
+        The language directory, e.g., data/lang_phone or data/lang_bpe.

     Return:
       An FSA representing HLG.
@@ -45,7 +45,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
         logging.info("Loading G_3_gram.fst.txt")
         with open("data/lm/G_3_gram.fst.txt") as f:
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(G.as_dict(), "G_3_gram.pt")
+            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")

     first_token_disambig_id = lexicon.token_table["#0"]
     first_word_disambig_id = lexicon.word_table["#0"]
@@ -103,30 +103,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     return HLG


-def phone_based_HLG():
-    if Path("data/lm/HLG.pt").is_file():
-        return
-
-    logging.info("Compiling phone based HLG")
-    HLG = compile_HLG("data/lang")
-    logging.info("Saving HLG.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG.pt")
-
-
-def bpe_based_HLG():
-    if Path("data/lm/HLG_bpe.pt").is_file():
-        return
-
-    logging.info("Compiling BPE based HLG")
-    HLG = compile_HLG("data/lang/bpe")
-    logging.info("Saving HLG_bpe.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
-
-
 def main():
-    phone_based_HLG()
-    bpe_based_HLG()
+    for d in ["data/lang_phone", "data/lang_bpe"]:
+        d = Path(d)
+        logging.info(f"Processing {d}")
+
+        if (d / "HLG.pt").is_file():
+            logging.info(f"{d}/HLG.pt already exists - skipping")
+            continue
+
+        HLG = compile_HLG(d)
+        logging.info(f"Saving HLG.pt to {d}")
+        torch.save(HLG.as_dict(), f"{d}/HLG.pt")


 if __name__ == "__main__":

egs/librispeech/ASR/local/compute_fbank_librispeech.py

@@ -1,19 +1,28 @@
 #!/usr/bin/env python3

 """
-This file computes fbank features of the librispeech dataset.
-Its looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+This file computes fbank features of the LibriSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """

+import logging
 import os
 from pathlib import Path

+import torch
 from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor

+# Torch's multithreaded behavior needs to be disabled or it wastes a lot of
+# CPU and slows things down. Do this outside of main() in case it needs to
+# take effect even when we are not invoking the main (e.g. when spawning
+# subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+

 def compute_fbank_librispeech():
     src_dir = Path("data/manifests")
@@ -40,12 +49,11 @@ def compute_fbank_librispeech():
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
             if (output_dir / f"cuts_{partition}.json.gz").is_file():
-                print(f"{partition} already exists - skipping.")
+                logging.info(f"{partition} already exists - skipping.")
                 continue
-            print("Processing", partition)
+            logging.info(f"Processing {partition}")
             cut_set = CutSet.from_manifests(
-                recordings=m["recordings"],
-                supervisions=m["supervisions"],
+                recordings=m["recordings"], supervisions=m["supervisions"],
             )
             if "train" in partition:
                 cut_set = (
@@ -65,4 +73,10 @@ def compute_fbank_librispeech():

 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
     compute_fbank_librispeech()

egs/librispeech/ASR/local/compute_fbank_musan.py

@@ -2,18 +2,27 @@

 """
 This file computes fbank features of the musan dataset.
-Its looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """

+import logging
 import os
 from pathlib import Path

+import torch
 from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor

+# Torch's multithreaded behavior needs to be disabled or it wastes a lot of
+# CPU and slows things down. Do this outside of main() in case it needs to
+# take effect even when we are not invoking the main (e.g. when spawning
+# subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+

 def compute_fbank_musan():
     src_dir = Path("data/manifests")
@@ -34,10 +43,10 @@ def compute_fbank_musan():
     musan_cuts_path = output_dir / "cuts_musan.json.gz"

     if musan_cuts_path.is_file():
-        print(f"{musan_cuts_path} already exists - skipping")
+        logging.info(f"{musan_cuts_path} already exists - skipping")
         return

-    print("Extracting features for Musan")
+    logging.info("Extracting features for Musan")

     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -63,4 +72,9 @@ def compute_fbank_musan():

 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
     compute_fbank_musan()

egs/librispeech/ASR/local/download_lm.py

@@ -2,10 +2,25 @@
 # Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)

 """
-This file downloads librispeech LM files to data/lm
+This file downloads the following LibriSpeech LM files:
+
+  - 3-gram.pruned.1e-7.arpa.gz
+  - 4-gram.arpa.gz
+  - librispeech-vocab.txt
+  - librispeech-lexicon.txt
+
+from http://www.openslr.org/resources/11
+and saves them in the user-provided directory.
+
+Files are not re-downloaded if they already exist.
+
+Usage:
+  ./local/download_lm.py --out-dir ./download/lm
 """

+import argparse
 import gzip
+import logging
 import os
 import shutil
 from pathlib import Path
@@ -14,9 +29,17 @@ from lhotse.utils import urlretrieve_progress
 from tqdm.auto import tqdm


-def download_lm():
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--out-dir", type=str, help="Output directory.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main(out_dir: str):
     url = "http://www.openslr.org/resources/11"
-    target_dir = Path("data/lm")
+    out_dir = Path(out_dir)

     files_to_download = (
         "3-gram.pruned.1e-7.arpa.gz",
@@ -26,7 +49,7 @@ def download_lm():
     )

     for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
-        filename = target_dir / f
+        filename = out_dir / f
         if filename.is_file() is False:
             urlretrieve_progress(
                 f"{url}/{f}",
@@ -34,17 +57,26 @@ def download_lm():
                 desc=f"Downloading {filename}",
             )
         else:
-            print(f"{filename} already exists - skipping")
+            logging.info(f"{filename} already exists - skipping")

         if ".gz" in str(filename):
-            unzip_file = Path(os.path.splitext(filename)[0])
-            if unzip_file.is_file() is False:
+            unzipped = Path(os.path.splitext(filename)[0])
+            if unzipped.is_file() is False:
                 with gzip.open(filename, "rb") as f_in:
-                    with open(unzip_file, "wb") as f_out:
+                    with open(unzipped, "wb") as f_out:
                         shutil.copyfileobj(f_in, f_out)
             else:
-                print(f"{unzip_file} already exist - skipping")
+                logging.info(f"{unzipped} already exists - skipping")


 if __name__ == "__main__":
-    download_lm()
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    args = get_args()
+    logging.info(f"out_dir: {args.out_dir}")
+    main(out_dir=args.out_dir)

egs/librispeech/ASR/local/prepare_lang.py

@@ -3,7 +3,7 @@
 # Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)

 """
-This script takes as input a lexicon file "data/lang/lexicon.txt"
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
 consisting of words and tokens (i.e., phones) and does the following:

 1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
@@ -20,8 +20,6 @@ consisting of words and tokens (i.e., phones) and does the following:
 5. Generate L_disambig.pt, in k2 format.
 """
 import math
-import re
-import sys
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple
@@ -284,7 +282,9 @@ def lexicon_to_fst(
     disambig_token = token2id["#0"]
     disambig_word = word2id["#0"]
     arcs = add_self_loops(
-        arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        arcs,
+        disambig_token=disambig_token,
+        disambig_word=disambig_word,
     )

     final_state = next_state
@@ -301,7 +301,7 @@ def lexicon_to_fst(


 def main():
-    out_dir = Path("data/lang")
+    out_dir = Path("data/lang_phone")
     lexicon_filename = out_dir / "lexicon.txt"
     sil_token = "SIL"
     sil_prob = 0.5

egs/librispeech/ASR/local/prepare_lang_bpe.py

@@ -5,10 +5,10 @@
 """
 This script takes as inputs the following two files:

-    - data/lang/bpe/bpe.model,
-    - data/lang/bpe/words.txt
+    - data/lang_bpe/bpe.model,
+    - data/lang_bpe/words.txt

-and generates the following files in the directory data/lang/bpe:
+and generates the following files in the directory data/lang_bpe:

     - lexicon.txt
     - lexicon_disambig.txt
@@ -88,7 +88,9 @@ def lexicon_to_fst_no_sil(
     disambig_token = token2id["#0"]
     disambig_word = word2id["#0"]
     arcs = add_self_loops(
-        arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        arcs,
+        disambig_token=disambig_token,
+        disambig_word=disambig_word,
     )

     final_state = next_state
@@ -140,7 +142,7 @@ def generate_lexicon(


 def main():
-    lang_dir = Path("data/lang/bpe")
+    lang_dir = Path("data/lang_bpe")
     model_file = lang_dir / "bpe.model"

     word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
@@ -173,7 +175,9 @@ def main():
     write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

     L = lexicon_to_fst_no_sil(
-        lexicon, token2id=token_sym_table, word2id=word_sym_table,
+        lexicon,
+        token2id=token_sym_table,
+        word2id=word_sym_table,
     )

     L_disambig = lexicon_to_fst_no_sil(

egs/librispeech/ASR/local/train_bpe_model.py

@@ -14,18 +14,17 @@ and generates "data/lang/bpe/bep.model".
 #
 # Please install a version >=0.1.96

+import shutil
 from pathlib import Path

 import sentencepiece as spm
-import shutil


 def main():
     model_type = "unigram"
     vocab_size = 5000
-    model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}"
-    train_text = "data/lang/bpe/train.txt"
+    model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
+    train_text = "data/lang_bpe/train.txt"
     character_coverage = 1.0
     input_sentence_size = 100000000

@@ -53,7 +52,7 @@ def main():
     sp = spm.SentencePieceProcessor(model_file=str(model_file))
     vocab_size = sp.vocab_size()

-    shutil.copyfile(model_file, "data/lang/bpe/bpe.model")
+    shutil.copyfile(model_file, "data/lang_bpe/bpe.model")


 if __name__ == "__main__":

egs/librispeech/ASR/prepare.sh

@@ -6,8 +6,38 @@ nj=15
 stage=-1
 stop_stage=100

-. local/parse_options.sh || exit 1
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/LibriSpeech
+#    You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
+#    You can download them from https://www.openslr.org/12
+#
+#  - $dl_dir/lm
+#    This directory contains the following files downloaded from
+#    http://www.openslr.org/resources/11
+#
+#    - 3-gram.pruned.1e-7.arpa.gz
+#    - 3-gram.pruned.1e-7.arpa
+#    - 4-gram.arpa.gz
+#    - 4-gram.arpa
+#    - librispeech-vocab.txt
+#    - librispeech-lexicon.txt
+#
+#  - $dl_dir/musan
+#    This directory contains the following directories downloaded from
+#    http://www.openslr.org/17/
+#
+#    - music
+#    - noise
+#    - speech
+
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1

+# All generated files by this script are saved in "data"
 mkdir -p data

 log() {
@@ -16,10 +46,11 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }

+log "dl_dir: $dl_dir"
+
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: Download LM"
-  mkdir -p data/lm
-  ./local/download_lm.py
+  ./local/download_lm.py --out-dir=$dl_dir/lm
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@@ -28,38 +59,28 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/LibriSpeech data/
+  #   ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
   #
-  # The script checks that if
-  #
-  #   data/LibriSpeech/test-clean/.completed exists,
-  #
-  # it will not re-download it.
-  #
-  # The same goes for dev-clean, dev-other, test-other, train-clean-100
-  # train-clean-360, and train-other-500
-
-  mkdir -p data/LibriSpeech
-  lhotse download librispeech --full data
+  if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
+    lhotse download librispeech --full $dl_dir
+  fi

   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/musan data/
+  #   ln -sfv /path/to/musan $dl_dir/
   #
-  # and create a file data/.musan_completed
-  # to avoid downloading it again
-  if [ ! -f data/.musan_completed ]; then
-    lhotse download musan data
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
   fi
 fi

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare librispeech manifest"
-  # We assume that you have downloaded the librispeech corpus
-  # to data/LibriSpeech
+  log "Stage 1: Prepare LibriSpeech manifest"
+  # We assume that you have downloaded the LibriSpeech corpus
+  # to $dl_dir/LibriSpeech
   mkdir -p data/manifests
-  lhotse prepare librispeech -j $nj data/LibriSpeech data/manifests
+  lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -67,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # We assume that you have downloaded the musan corpus
   # to data/musan
   mkdir -p data/manifests
-  lhotse prepare musan data/musan data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@@ -84,24 +105,25 @@ fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
-  # TODO: add BPE based lang
-  mkdir -p data/lang
+  mkdir -p data/lang_phone

   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-    cat - data/lm/librispeech-lexicon.txt |
-    sort | uniq > data/lang/lexicon.txt
+    cat - $dl_dir/lm/librispeech-lexicon.txt |
+    sort | uniq > data/lang_phone/lexicon.txt

-  if [ ! -f data/lang/L_disambig.pt ]; then
+  if [ ! -f data/lang_phone/L_disambig.pt ]; then
     ./local/prepare_lang.py
   fi
 fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "Stage 6: Prepare BPE based lang"

-  mkdir -p data/lang/bpe
-  cp data/lang/words.txt data/lang/bpe/
+  mkdir -p data/lang_bpe
+  # We reuse words.txt from phone based lexicon
+  # so that the two can share G.pt later.
+  cp data/lang_phone/words.txt data/lang_bpe/

-  if [ ! -f data/lang/bpe/train.txt ]; then
+  if [ ! -f data/lang_bpe/train.txt ]; then
     log "Generate data for BPE training"
     files=$(
       find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"
@@ -110,12 +132,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     )
     for f in ${files[@]}; do
       cat $f | cut -d " " -f 2-
-    done > data/lang/bpe/train.txt
+    done > data/lang_bpe/train.txt
   fi

   python3 ./local/train_bpe_model.py

-  if [ ! -f data/lang/bpe/L_disambig.pt ]; then
+  if [ ! -f data/lang_bpe/L_disambig.pt ]; then
     ./local/prepare_lang_bpe.py
   fi
 fi

@@ -125,22 +147,23 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm

+  mkdir -p data/lm
   if [ ! -f data/lm/G_3_gram.fst.txt ]; then
     # It is used in building HLG
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=3 \
-      data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
+      $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
   fi

   if [ ! -f data/lm/G_4_gram.fst.txt ]; then
     # It is used for LM rescoring
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=4 \
-      data/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
+      $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
   fi
 fi

egs/librispeech/ASR/shared (symbolic link)

@@ -0,0 +1 @@
+../../../icefall/shared/

egs/librispeech/ASR/tdnn_lstm_ctc/decode.py

@@ -58,7 +58,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp/"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
            "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "subsampling_factor": 3,
@@ -328,7 +328,7 @@ def main():

     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
+    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -340,7 +340,7 @@ def main():
         logging.info("Loading G_4_gram.fst.txt")
         logging.warning("It may take 8 minutes.")
         with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-            first_word_disambig_id = lexicon.words["#0"]
+            first_word_disambig_id = lexicon.word_table["#0"]
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
             # G.aux_labels is not needed in later computations, so

egs/librispeech/ASR/tdnn_lstm_ctc/train.py

@@ -127,7 +127,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
             "lr": 1e-3,
             "feature_dim": 80,
             "weight_decay": 5e-4,
@@ -501,6 +501,7 @@ def run(rank, world_size, args):
     )
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

+    if checkpoints:
         optimizer.load_state_dict(checkpoints["optimizer"])
         scheduler.load_state_dict(checkpoints["scheduler"])

icefall/decode.py

@@ -555,11 +555,14 @@ def rescore_with_attention_decoder(
     model: nn.Module,
     memory: torch.Tensor,
     memory_key_padding_mask: torch.Tensor,
+    sos_id: int,
+    eos_id: int,
 ) -> Dict[str, k2.Fsa]:
     """This function extracts n paths from the given lattice and uses
     an attention decoder to rescore them. The path with the highest
     score is used as the decoding output.

+    Args:
       lattice:
         An FsaVec. It can be the return value of :func:`get_lattice`.
       num_paths:
@@ -573,6 +576,10 @@ def rescore_with_attention_decoder(
         Its shape is `[T, N, C]`.
       memory_key_padding_mask:
         The padding mask for memory with shape [N, T].
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
     Returns:
       A dict of FsaVec, whose key contains a string
       ngram_lm_scale_attention_scale and the value is the
@@ -661,7 +668,11 @@ def rescore_with_attention_decoder(

     # TODO: pass the sos_token_id and eos_token_id via function arguments
     nll = model.decoder_nll(
-        expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
     )
     assert nll.ndim == 2
     assert nll.shape[0] == num_word_seqs
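Callers must now supply the markers explicitly; a sketch of the updated call, with lattice, model, memory and the padding mask assumed to come from the encoder as in conformer_ctc/decode.py:

    best_path_dict = rescore_with_attention_decoder(
        lattice=lattice,
        num_paths=100,
        model=model,
        memory=memory,
        memory_key_padding_mask=memory_key_padding_mask,
        sos_id=graph_compiler.sos_id,
        eos_id=graph_compiler.eos_id,
    )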

icefall/lexicon.py

@@ -1,7 +1,8 @@
 import logging
 import re
+import sys
 from pathlib import Path
-from typing import List, Tuple, Union
+from typing import List, Tuple

 import k2
 import torch
@@ -31,13 +32,19 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
                 continue

             if len(a) < 2:
-                print(f"Found bad line {line} in lexicon file {filename}")
-                print("Every line is expected to contain at least 2 fields")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
+                logging.info(
+                    "Every line is expected to contain at least 2 fields"
+                )
                 sys.exit(1)

             word = a[0]
             if word == "<eps>":
-                print(f"Found bad line {line} in lexicon file {filename}")
-                print("<eps> should not be a valid word")
+                logging.info(
+                    f"Found bad line {line} in lexicon file {filename}"
+                )
+                logging.info("<eps> should not be a valid word")
                 sys.exit(1)

             tokens = a[1:]
@@ -61,13 +68,12 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:

 class Lexicon(object):
-    """Phone based lexicon.
-
-    TODO: Add BpeLexicon for BPE models.
-    """
+    """Phone based lexicon."""

     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Args:
@@ -121,7 +127,9 @@ class Lexicon(object):

 class BpeLexicon(Lexicon):
     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Refer to the help information in Lexicon.__init__.

icefall/utils.py

@@ -226,7 +226,10 @@ def store_transcripts(


 def write_error_stats(
-    f: TextIO, test_set_name: str, results: List[Tuple[str, str]]
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, str]],
+    enable_log: bool = True,
 ) -> float:
     """Write statistics based on predicted results and reference transcripts.
@@ -256,6 +259,9 @@ def write_error_stats(
       results:
        An iterable of tuples. The first element is the reference transcript
        while the second element is the predicted result.
+      enable_log:
+        If True, also print detailed WER to the console.
+        Otherwise, it is written only to the given file.
     Returns:
       Return None.
     """
@@ -291,6 +297,7 @@ def write_error_stats(
     tot_errs = sub_errs + ins_errs + del_errs
     tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

+    if enable_log:
         logging.info(
             f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
             f"[{tot_errs} / {ref_len}, {ins_errs} ins, "