Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Refactoring (#4)
* Fix an error in TDNN-LSTM training.
* WIP: Refactoring
* Refactor transformer.py
* Remove unused code.
* Minor fixes.
This commit is contained in:
parent cf8d76293d
commit 5a0b9bcb23
1  .gitignore  (vendored)

@@ -4,3 +4,4 @@ path.sh
 exp
 exp*/
 *.pt
+download/
@@ -84,20 +84,26 @@ class Conformer(Transformer):
         # and throws an error without this change.
         self.after_norm = identity

-    def encode(
+    def run_encoder(
        self, x: Tensor, supervisions: Optional[Supervisions] = None
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Args:
-          x: Tensor of dimension (batch_size, num_features, input_length).
-          supervisions : Supervision in lhotse format, i.e., batch['supervisions']
+          x:
+            The model input. Its shape is [N, T, C].
+          supervisions:
+            Supervision in lhotse format.
+            See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+            CAUTION: It contains length information, i.e., start and number of
+            frames, before subsampling.
+            It is read directly from the batch, without any sorting. It is used
+            to compute the encoder padding mask, which is used as the memory key
+            padding mask for the decoder.

        Returns:
            Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
            Tensor: Mask tensor of dimension (batch_size, input_length)
        """
-        x = x.permute(0, 2, 1)  # (B, F, T) -> (B, T, F)
-
        x = self.encoder_embed(x)
        x, pos_emb = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)
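For orientation, here is a minimal sketch of the new calling convention. The constructor arguments are hypothetical (they mirror the ones test_transformer.py below uses for Transformer); the point is that run_encoder now takes features laid out as [N, T, C], with no permute beforehand:

import torch
from conformer import Conformer  # assuming conformer_ctc/ is on the path

model = Conformer(num_features=80, num_classes=500)  # hypothetical sizes
x = torch.rand(2, 100, 80)  # [N, T, C]; no permute to [N, C, T] needed
memory, mask = model.run_encoder(x, supervisions=None)
# memory is laid out as (T', N, d_model); the mask, used later as the
# decoder's memory key padding mask, is computed from supervisions and
# so is presumably None when none are given.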
@@ -15,6 +15,7 @@ import torch
 import torch.nn as nn
 from conformer import Conformer

+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
@@ -62,7 +63,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "nhead": 8,
@@ -85,7 +86,7 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "1best",
+            "method": "attention-decoder",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
             "num_paths": 100,
@@ -100,6 +101,8 @@ def decode_one_batch(
     HLG: k2.Fsa,
     batch: dict,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -133,6 +136,10 @@ def decode_one_batch(
         for the format of the `batch`.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -147,15 +154,10 @@ def decode_one_batch(
     feature = feature.to(device)
     # at entry, feature is [N, T, C]

-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
-
     supervisions = batch["supervisions"]

     nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is [N, C, T]
-
-    nnet_output = nnet_output.permute(0, 2, 1)
-    # now nnet_output is [N, T, C]
+    # nnet_output is [N, T, C]

     supervision_segments = torch.stack(
         (
@@ -227,6 +229,8 @@ def decode_one_batch(
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -245,6 +249,8 @@ def decode_dataset(
     model: nn.Module,
     HLG: k2.Fsa,
     lexicon: Lexicon,
+    sos_id: int,
+    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
     """Decode dataset.
@@ -260,6 +266,10 @@ def decode_dataset(
         The decoding graph.
       lexicon:
         It contains word symbol table.
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -287,6 +297,8 @@ def decode_dataset(
             batch=batch,
             lexicon=lexicon,
             G=G,
+            sos_id=sos_id,
+            eos_id=eos_id,
         )

         for lm_scale, hyps in hyps_dict.items():
@@ -314,20 +326,31 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
 ):
+    if params.method == "attention-decoder":
+        # Set it to False since there are too many logs.
+        enable_log = False
+    else:
+        enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
         store_transcripts(filename=recog_path, texts=results)
-        logging.info(f"The transcripts are stored in {recog_path}")
+        if enable_log:
+            logging.info(f"The transcripts are stored in {recog_path}")

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
         errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
         with open(errs_filename, "w") as f:
-            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=enable_log
+            )
             test_set_wers[key] = wer

-        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+        if enable_log:
+            logging.info(
+                "Wrote detailed error stats to {}".format(errs_filename)
+            )

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -367,15 +390,22 @@ def main():
     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt"))
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+    sos_id = graph_compiler.sos_id
+    eos_id = graph_compiler.eos_id
+
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False

     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()

-    # HLG = k2.ctc_topo(4999).to(device)
-
     if params.method in (
         "nbest-rescoring",
         "whole-lattice-rescoring",
@@ -461,6 +491,8 @@ def main():
         HLG=HLG,
         lexicon=lexicon,
         G=G,
+        sos_id=sos_id,
+        eos_id=eos_id,
     )

     save_results(
@@ -470,5 +502,8 @@ def main():
     logging.info("Done!")


+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
 if __name__ == "__main__":
     main()
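Taken together, these decode.py changes thread sos_id and eos_id from the BPE graph compiler down to attention-decoder rescoring instead of hard-coding them. A condensed sketch of the new wiring, with the surrounding setup elided:

from pathlib import Path

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler

# The compiler owns the SOS/EOS token IDs; decode_dataset and
# decode_one_batch now receive them explicitly.
graph_compiler = BpeCtcTrainingGraphCompiler(
    Path("data/lang_bpe"),
    device="cpu",  # the script itself picks a CUDA device when available
    sos_token="<sos/eos>",
    eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
# ... then: decode_dataset(..., sos_id=sos_id, eos_id=eos_id, G=G)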
144  egs/librispeech/ASR/conformer_ctc/subsampling.py  (new file)

import torch
import torch.nn as nn


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Convert an input of shape [N, T, idim] to an output
    with shape [N, T', odim], where
    T' = ((T-1)//2 - 1)//2, which approximates T' == T//4.

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(self, idim: int, odim: int) -> None:
        """
        Args:
          idim:
            Input dim. The input shape is [N, T, idim].
            Caution: It requires: T >= 7, idim >= 7.
          odim:
            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim].
        """
        assert idim >= 7
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1, out_channels=odim, kernel_size=3, stride=2
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=odim, out_channels=odim, kernel_size=3, stride=2
            ),
            nn.ReLU(),
        )
        self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is [N, T, idim].

        Returns:
          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim].
        """
        # On entry, x is [N, T, idim]
        x = x.unsqueeze(1)  # [N, T, idim] -> [N, 1, T, idim], i.e., [N, C, H, W]
        x = self.conv(x)
        # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape [N, ((T-1)//2 - 1)//2, odim]
        return x


class VggSubsampling(nn.Module):
    """Trying to follow the setup described in the following paper:
    https://arxiv.org/pdf/1910.09799.pdf

    This paper is not 100% explicit, so I am guessing to some extent
    and trying to compare with other VGG implementations.

    Convert an input of shape [N, T, idim] to an output
    with shape [N, T', odim], where
    T' = ((T-1)//2 - 1)//2, which approximates T' = T//4.
    """

    def __init__(self, idim: int, odim: int) -> None:
        """Construct a VggSubsampling object.

        This uses 2 VGG blocks with 2 Conv2d layers each,
        subsampling its input by a factor of 4 in the time dimension.

        Args:
          idim:
            Input dim. The input shape is [N, T, idim].
            Caution: It requires: T >= 7, idim >= 7.
          odim:
            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim].
        """
        super().__init__()

        cur_channels = 1
        layers = []
        block_dims = [32, 64]

        # The decision to use padding=1 for the 1st convolution, then padding=0
        # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
        # a back-compatibility concern so that the number of frames at the
        # output would be equal to:
        #   (((T-1)//2)-1)//2.
        # We can consider changing this by using padding=1 on the
        # 2nd convolution, so the num-frames at the output would be T//4.
        for block_dim in block_dims:
            layers.append(
                torch.nn.Conv2d(
                    in_channels=cur_channels,
                    out_channels=block_dim,
                    kernel_size=3,
                    padding=1,
                    stride=1,
                )
            )
            layers.append(torch.nn.ReLU())
            layers.append(
                torch.nn.Conv2d(
                    in_channels=block_dim,
                    out_channels=block_dim,
                    kernel_size=3,
                    padding=0,
                    stride=1,
                )
            )
            layers.append(
                torch.nn.MaxPool2d(
                    kernel_size=2, stride=2, padding=0, ceil_mode=True
                )
            )
            cur_channels = block_dim

        self.layers = nn.Sequential(*layers)

        self.out = nn.Linear(
            block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is [N, T, idim].

        Returns:
          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim].
        """
        x = x.unsqueeze(1)
        x = self.layers(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        return x
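Why T' = ((T-1)//2 - 1)//2: a kernel-3, stride-2, padding-0 convolution maps a length L to (L - 3)//2 + 1, which simplifies to (L - 1)//2, and Conv2dSubsampling stacks two of them. A standalone check (a sketch, not part of the commit):

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, kernel_size=3, stride=2)
for T in range(7, 20):
    x = torch.rand(1, 1, T, 7)  # [N, C, H=T, W]
    assert conv(x).shape[2] == (T - 1) // 2
    assert conv(conv(x)).shape[2] == ((T - 1) // 2 - 1) // 2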
33  egs/librispeech/ASR/conformer_ctc/test_subsampling.py  (new executable file)

#!/usr/bin/env python3

from subsampling import Conv2dSubsampling
from subsampling import VggSubsampling
import torch


def test_conv2d_subsampling():
    N = 3
    odim = 2

    for T in range(7, 19):
        for idim in range(7, 20):
            model = Conv2dSubsampling(idim=idim, odim=odim)
            x = torch.empty(N, T, idim)
            y = model(x)
            assert y.shape[0] == N
            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
            assert y.shape[2] == odim


def test_vgg_subsampling():
    N = 3
    odim = 2

    for T in range(7, 19):
        for idim in range(7, 20):
            model = VggSubsampling(idim=idim, odim=odim)
            x = torch.empty(N, T, idim)
            y = model(x)
            assert y.shape[0] == N
            assert y.shape[1] == ((T - 1) // 2 - 1) // 2
            assert y.shape[2] == odim
89  egs/librispeech/ASR/conformer_ctc/test_transformer.py  (new file)

#!/usr/bin/env python3

import torch
from transformer import (
    Transformer,
    encoder_padding_mask,
    generate_square_subsequent_mask,
    decoder_padding_mask,
    add_sos,
    add_eos,
)

from torch.nn.utils.rnn import pad_sequence


def test_encoder_padding_mask():
    supervisions = {
        "sequence_idx": torch.tensor([0, 1, 2]),
        "start_frame": torch.tensor([0, 0, 0]),
        "num_frames": torch.tensor([18, 7, 13]),
    }

    max_len = ((18 - 1) // 2 - 1) // 2
    mask = encoder_padding_mask(max_len, supervisions)
    expected_mask = torch.tensor(
        [
            [False, False, False],  # ((18 - 1)//2 - 1)//2 = 3,
            [False, True, True],  # ((7 - 1)//2 - 1)//2 = 1,
            [False, False, True],  # ((13 - 1)//2 - 1)//2 = 2,
        ]
    )
    assert torch.all(torch.eq(mask, expected_mask))


def test_transformer():
    num_features = 40
    num_classes = 87
    model = Transformer(num_features=num_features, num_classes=num_classes)

    N = 31

    for T in range(7, 30):
        x = torch.rand(N, T, num_features)
        y, _, _ = model(x)
        assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)


def test_generate_square_subsequent_mask():
    s = 5
    mask = generate_square_subsequent_mask(s)
    inf = float("inf")
    expected_mask = torch.tensor(
        [
            [0.0, -inf, -inf, -inf, -inf],
            [0.0, 0.0, -inf, -inf, -inf],
            [0.0, 0.0, 0.0, -inf, -inf],
            [0.0, 0.0, 0.0, 0.0, -inf],
            [0.0, 0.0, 0.0, 0.0, 0.0],
        ]
    )
    assert torch.all(torch.eq(mask, expected_mask))


def test_decoder_padding_mask():
    x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
    y = pad_sequence(x, batch_first=True, padding_value=-1)
    mask = decoder_padding_mask(y, ignore_id=-1)
    expected_mask = torch.tensor(
        [
            [False, False, True],
            [False, True, True],
            [False, False, False],
        ]
    )
    assert torch.all(torch.eq(mask, expected_mask))


def test_add_sos():
    x = [[1, 2], [3], [2, 5, 8]]
    y = add_sos(x, sos_id=0)
    expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
    assert y == expected_y


def test_add_eos():
    x = [[1, 2], [3], [2, 5, 8]]
    y = add_eos(x, eos_id=0)
    expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
    assert y == expected_y
@@ -125,7 +125,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang/bpe"),
+            "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "weight_decay": 0.0,
             "subsampling_factor": 4,
@@ -275,15 +275,13 @@ def compute_loss(
     device = graph_compiler.device
     feature = batch["inputs"]
     # at entry, feature is [N, T, C]
-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
     assert feature.ndim == 3
     feature = feature.to(device)

     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, C, T]
-        nnet_output = nnet_output.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+        # nnet_output is [N, T, C]

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
@@ -536,6 +534,22 @@ def train_one_epoch(
                 f" best valid loss: {params.best_valid_loss:.4f} "
                 f"best valid epoch: {params.best_valid_epoch}"
             )
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/valid_ctc_loss",
+                    params.valid_ctc_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/valid_att_loss",
+                    params.valid_att_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/valid_loss",
+                    params.valid_loss,
+                    params.batch_idx_train,
+                )

     params.train_loss = tot_loss / tot_frames

@@ -675,5 +689,8 @@ def main():
     run(rank=0, world_size=1, args=args)


+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
 if __name__ == "__main__":
     main()
(File diff suppressed because it is too large.)
@@ -26,7 +26,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang or data/lang/bpe.
+        The language directory, e.g., data/lang_phone or data/lang_bpe.

     Return:
       An FSA representing HLG.
@@ -45,7 +45,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
         logging.info("Loading G_3_gram.fst.txt")
         with open("data/lm/G_3_gram.fst.txt") as f:
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(G.as_dict(), "G_3_gram.pt")
+            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")

     first_token_disambig_id = lexicon.token_table["#0"]
     first_word_disambig_id = lexicon.word_table["#0"]
@@ -103,30 +103,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     return HLG


-def phone_based_HLG():
-    if Path("data/lm/HLG.pt").is_file():
-        return
-
-    logging.info("Compiling phone based HLG")
-    HLG = compile_HLG("data/lang")
-
-    logging.info("Saving HLG.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG.pt")
-
-
-def bpe_based_HLG():
-    if Path("data/lm/HLG_bpe.pt").is_file():
-        return
-
-    logging.info("Compiling BPE based HLG")
-    HLG = compile_HLG("data/lang/bpe")
-    logging.info("Saving HLG_bpe.pt to data/lm")
-    torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
-
-
 def main():
-    phone_based_HLG()
-    bpe_based_HLG()
+    for d in ["data/lang_phone", "data/lang_bpe"]:
+        d = Path(d)
+        logging.info(f"Processing {d}")
+
+        if (d / "HLG.pt").is_file():
+            logging.info(f"{d}/HLG.pt already exists - skipping")
+            continue
+
+        HLG = compile_HLG(d)
+        logging.info(f"Saving HLG.pt to {d}")
+        torch.save(HLG.as_dict(), f"{d}/HLG.pt")


 if __name__ == "__main__":
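With this change each lang directory carries its own HLG.pt, replacing the shared data/lm/HLG.pt and HLG_bpe.pt. Loading follows suit; a minimal sketch (the tdnn_lstm_ctc decode change later in this diff does exactly this for the phone lang dir):

import k2
import torch

# The decoding graph now lives next to the lexicon that produced it.
HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
# The BPE counterpart would be data/lang_bpe/HLG.pt.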
@@ -1,19 +1,28 @@
 #!/usr/bin/env python3

 """
-This file computes fbank features of the librispeech dataset.
-It looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+This file computes fbank features of the LibriSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """

+import logging
 import os
 from pathlib import Path

 import torch
 from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor

+# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
+# slows things down. Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+

 def compute_fbank_librispeech():
     src_dir = Path("data/manifests")
@@ -40,12 +49,11 @@ def compute_fbank_librispeech():
     with get_executor() as ex:  # Initialize the executor only once.
         for partition, m in manifests.items():
             if (output_dir / f"cuts_{partition}.json.gz").is_file():
-                print(f"{partition} already exists - skipping.")
+                logging.info(f"{partition} already exists - skipping.")
                 continue
-            print("Processing", partition)
+            logging.info(f"Processing {partition}")
             cut_set = CutSet.from_manifests(
-                recordings=m["recordings"],
-                supervisions=m["supervisions"],
+                recordings=m["recordings"], supervisions=m["supervisions"],
             )
             if "train" in partition:
                 cut_set = (
@@ -65,4 +73,10 @@ def compute_fbank_librispeech():


 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
     compute_fbank_librispeech()
@@ -2,18 +2,27 @@

 """
 This file computes fbank features of the musan dataset.
-It looks for manifests in the directory data/manifests
-and generated fbank features are saved in data/fbank.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
 """

+import logging
 import os
 from pathlib import Path

 import torch
 from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor

+# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and
+# slows things down. Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+

 def compute_fbank_musan():
     src_dir = Path("data/manifests")
@@ -34,10 +43,10 @@ def compute_fbank_musan():
     musan_cuts_path = output_dir / "cuts_musan.json.gz"

     if musan_cuts_path.is_file():
-        print(f"{musan_cuts_path} already exists - skipping")
+        logging.info(f"{musan_cuts_path} already exists - skipping")
         return

-    print("Extracting features for Musan")
+    logging.info("Extracting features for Musan")

     extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -63,4 +72,9 @@ def compute_fbank_musan():


 if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
     compute_fbank_musan()
@@ -2,10 +2,25 @@

 # Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)
 """
-This file downloads librispeech LM files to data/lm
+This file downloads the following LibriSpeech LM files:
+
+  - 3-gram.pruned.1e-7.arpa.gz
+  - 4-gram.arpa.gz
+  - librispeech-vocab.txt
+  - librispeech-lexicon.txt
+
+from http://www.openslr.org/resources/11
+and saves them in the user-provided directory.
+
+Files are not re-downloaded if they already exist.
+
+Usage:
+    ./local/download_lm.py --out-dir ./download/lm
 """

+import argparse
 import gzip
 import logging
 import os
 import shutil
 from pathlib import Path
@@ -14,9 +29,17 @@ from lhotse.utils import urlretrieve_progress
 from tqdm.auto import tqdm


-def download_lm():
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--out-dir", type=str, help="Output directory.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main(out_dir: str):
     url = "http://www.openslr.org/resources/11"
-    target_dir = Path("data/lm")
+    out_dir = Path(out_dir)

     files_to_download = (
         "3-gram.pruned.1e-7.arpa.gz",
@@ -26,7 +49,7 @@ def main(out_dir: str):
     )

     for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"):
-        filename = target_dir / f
+        filename = out_dir / f
         if filename.is_file() is False:
             urlretrieve_progress(
                 f"{url}/{f}",
@@ -34,17 +57,26 @@ def main(out_dir: str):
                 desc=f"Downloading {filename}",
             )
         else:
-            print(f"{filename} already exists - skipping")
+            logging.info(f"{filename} already exists - skipping")

         if ".gz" in str(filename):
-            unzip_file = Path(os.path.splitext(filename)[0])
-            if unzip_file.is_file() is False:
+            unzipped = Path(os.path.splitext(filename)[0])
+            if unzipped.is_file() is False:
                 with gzip.open(filename, "rb") as f_in:
-                    with open(unzip_file, "wb") as f_out:
+                    with open(unzipped, "wb") as f_out:
                         shutil.copyfileobj(f_in, f_out)
             else:
-                print(f"{unzip_file} already exists - skipping")
+                logging.info(f"{unzipped} already exists - skipping")


 if __name__ == "__main__":
-    download_lm()
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    args = get_args()
+    logging.info(f"out_dir: {args.out_dir}")
+
+    main(out_dir=args.out_dir)
@@ -3,7 +3,7 @@
 # Copyright (c)  2021  Xiaomi Corporation (authors: Fangjun Kuang)

 """
-This script takes as input a lexicon file "data/lang/lexicon.txt"
+This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
 consisting of words and tokens (i.e., phones) and does the following:

 1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt
@@ -20,8 +20,6 @@ consisting of words and tokens (i.e., phones) and does the following:
 5. Generate L_disambig.pt, in k2 format.
 """
-import math
-import re
 import sys
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List, Tuple
@@ -284,7 +282,9 @@ def lexicon_to_fst(
     disambig_token = token2id["#0"]
     disambig_word = word2id["#0"]
     arcs = add_self_loops(
-        arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        arcs,
+        disambig_token=disambig_token,
+        disambig_word=disambig_word,
     )

     final_state = next_state
@@ -301,7 +301,7 @@ def lexicon_to_fst(


 def main():
-    out_dir = Path("data/lang")
+    out_dir = Path("data/lang_phone")
     lexicon_filename = out_dir / "lexicon.txt"
     sil_token = "SIL"
     sil_prob = 0.5
@@ -5,10 +5,10 @@
 """
 This script takes as inputs the following two files:

-    - data/lang/bpe/bpe.model,
-    - data/lang/bpe/words.txt
+    - data/lang_bpe/bpe.model,
+    - data/lang_bpe/words.txt

-and generates the following files in the directory data/lang/bpe:
+and generates the following files in the directory data/lang_bpe:

     - lexicon.txt
     - lexicon_disambig.txt
@@ -88,7 +88,9 @@ def lexicon_to_fst_no_sil(
     disambig_token = token2id["#0"]
     disambig_word = word2id["#0"]
     arcs = add_self_loops(
-        arcs, disambig_token=disambig_token, disambig_word=disambig_word,
+        arcs,
+        disambig_token=disambig_token,
+        disambig_word=disambig_word,
     )

     final_state = next_state
@@ -140,7 +142,7 @@ def generate_lexicon(


 def main():
-    lang_dir = Path("data/lang/bpe")
+    lang_dir = Path("data/lang_bpe")
     model_file = lang_dir / "bpe.model"

     word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
@@ -173,7 +175,9 @@ def main():
     write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

     L = lexicon_to_fst_no_sil(
-        lexicon, token2id=token_sym_table, word2id=word_sym_table,
+        lexicon,
+        token2id=token_sym_table,
+        word2id=word_sym_table,
     )

     L_disambig = lexicon_to_fst_no_sil(
@@ -14,18 +14,17 @@ and generates "data/lang/bpe/bep.model".
 #
 # Please install a version >=0.1.96

+import shutil
 from pathlib import Path

 import sentencepiece as spm

-import shutil
-

 def main():
     model_type = "unigram"
     vocab_size = 5000
-    model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}"
-    train_text = "data/lang/bpe/train.txt"
+    model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
+    train_text = "data/lang_bpe/train.txt"
     character_coverage = 1.0
     input_sentence_size = 100000000

@@ -53,7 +52,7 @@ def main():
     sp = spm.SentencePieceProcessor(model_file=str(model_file))
     vocab_size = sp.vocab_size()

-    shutil.copyfile(model_file, "data/lang/bpe/bpe.model")
+    shutil.copyfile(model_file, "data/lang_bpe/bpe.model")


 if __name__ == "__main__":
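For reference, a small sketch of loading the copied model and inspecting it with the sentencepiece API used above (the sample sentence is arbitrary; LibriSpeech transcripts are uppercase):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="data/lang_bpe/bpe.model")
print(sp.vocab_size())  # 5000 with the settings above
print(sp.encode("HELLO WORLD", out_type=str))  # the learned pieces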
@@ -6,8 +6,38 @@ nj=15
 stage=-1
 stop_stage=100

-. local/parse_options.sh || exit 1
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/LibriSpeech
+#    You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it.
+#    You can download them from https://www.openslr.org/12
+#
+#  - $dl_dir/lm
+#    This directory contains the following files downloaded from
+#    http://www.openslr.org/resources/11
+#
+#    - 3-gram.pruned.1e-7.arpa.gz
+#    - 3-gram.pruned.1e-7.arpa
+#    - 4-gram.arpa.gz
+#    - 4-gram.arpa
+#    - librispeech-vocab.txt
+#    - librispeech-lexicon.txt
+#
+#  - $dl_dir/musan
+#    This directory contains the following directories downloaded from
+#    http://www.openslr.org/17/
+#
+#    - music
+#    - noise
+#    - speech
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1

 # All generated files by this script are saved in "data"
 mkdir -p data

 log() {
@@ -16,10 +46,11 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }

+log "dl_dir: $dl_dir"
+
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: Download LM"
   mkdir -p data/lm
-  ./local/download_lm.py
+  ./local/download_lm.py --out-dir=$dl_dir/lm
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@@ -28,38 +59,28 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/LibriSpeech data/
+  #   ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
   #
-  # The script checks that if
-  #
-  #   data/LibriSpeech/test-clean/.completed exists,
-  #
-  # it will not re-download it.
-  #
-  # The same goes for dev-clean, dev-other, test-other, train-clean-100
-  # train-clean-360, and train-other-500
-
-  mkdir -p data/LibriSpeech
-  lhotse download librispeech --full data
+  if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
+    lhotse download librispeech --full $dl_dir
+  fi

   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/musan data/
+  #   ln -sfv /path/to/musan $dl_dir/
   #
-  # and create a file data/.musan_completed
-  # to avoid downloading it again
-  if [ ! -f data/.musan_completed ]; then
-    lhotse download musan data
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
   fi
 fi

 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare librispeech manifest"
-  # We assume that you have downloaded the librispeech corpus
-  # to data/LibriSpeech
+  log "Stage 1: Prepare LibriSpeech manifest"
+  # We assume that you have downloaded the LibriSpeech corpus
+  # to $dl_dir/LibriSpeech
   mkdir -p data/manifests
-  lhotse prepare librispeech -j $nj data/LibriSpeech data/manifests
+  lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
 fi

 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -67,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # We assume that you have downloaded the musan corpus
   # to data/musan
   mkdir -p data/manifests
-  lhotse prepare musan data/musan data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@@ -84,24 +105,25 @@ fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
-  # TODO: add BPE based lang
-  mkdir -p data/lang
+  mkdir -p data/lang_phone

   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-    cat - data/lm/librispeech-lexicon.txt |
-    sort | uniq > data/lang/lexicon.txt
+    cat - $dl_dir/lm/librispeech-lexicon.txt |
+    sort | uniq > data/lang_phone/lexicon.txt

-  if [ ! -f data/lang/L_disambig.pt ]; then
+  if [ ! -f data/lang_phone/L_disambig.pt ]; then
     ./local/prepare_lang.py
   fi
 fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "Stage 6: Prepare BPE based lang"
-  mkdir -p data/lang/bpe
-  cp data/lang/words.txt data/lang/bpe/
+  mkdir -p data/lang_bpe
+  # We reuse words.txt from phone based lexicon
+  # so that the two can share G.pt later.
+  cp data/lang_phone/words.txt data/lang_bpe/

-  if [ ! -f data/lang/bpe/train.txt ]; then
+  if [ ! -f data/lang_bpe/train.txt ]; then
     log "Generate data for BPE training"
     files=$(
       find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"
@@ -110,12 +132,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
     )
     for f in ${files[@]}; do
       cat $f | cut -d " " -f 2-
-    done > data/lang/bpe/train.txt
+    done > data/lang_bpe/train.txt
   fi

   python3 ./local/train_bpe_model.py

-  if [ ! -f data/lang/bpe/L_disambig.pt ]; then
+  if [ ! -f data/lang_bpe/L_disambig.pt ]; then
     ./local/prepare_lang_bpe.py
   fi
 fi
@@ -125,22 +147,23 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm

+  mkdir -p data/lm
   if [ ! -f data/lm/G_3_gram.fst.txt ]; then
     # It is used in building HLG
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=3 \
-      data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
+      $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt
   fi

   if [ ! -f data/lm/G_4_gram.fst.txt ]; then
     # It is used for LM rescoring
     python3 -m kaldilm \
-      --read-symbol-table="data/lang/words.txt" \
+      --read-symbol-table="data/lang_phone/words.txt" \
       --disambig-symbol='#0' \
       --max-order=4 \
-      data/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
+      $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt
   fi
 fi
1  egs/librispeech/ASR/shared  (symbolic link)

@@ -0,0 +1 @@
+../../../icefall/shared/
@@ -58,7 +58,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp/"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
             "subsampling_factor": 3,
@@ -328,7 +328,7 @@ def main():

     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
+    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False

@@ -340,7 +340,7 @@ def main():
         logging.info("Loading G_4_gram.fst.txt")
         logging.warning("It may take 8 minutes.")
         with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-            first_word_disambig_id = lexicon.words["#0"]
+            first_word_disambig_id = lexicon.word_table["#0"]

             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
             # G.aux_labels is not needed in later computations, so
@@ -127,7 +127,7 @@ def get_params() -> AttributeDict:
     params = AttributeDict(
         {
             "exp_dir": Path("tdnn_lstm_ctc/exp"),
-            "lang_dir": Path("data/lang"),
+            "lang_dir": Path("data/lang_phone"),
             "lr": 1e-3,
             "feature_dim": 80,
             "weight_decay": 5e-4,
@@ -501,8 +501,9 @@ def run(rank, world_size, args):
     )
     scheduler = StepLR(optimizer, step_size=8, gamma=0.1)

-    optimizer.load_state_dict(checkpoints["optimizer"])
-    scheduler.load_state_dict(checkpoints["scheduler"])
+    if checkpoints:
+        optimizer.load_state_dict(checkpoints["optimizer"])
+        scheduler.load_state_dict(checkpoints["scheduler"])

     librispeech = LibriSpeechAsrDataModule(args)
     train_dl = librispeech.train_dataloaders()
@@ -555,24 +555,31 @@ def rescore_with_attention_decoder(
     model: nn.Module,
     memory: torch.Tensor,
     memory_key_padding_mask: torch.Tensor,
+    sos_id: int,
+    eos_id: int,
 ) -> Dict[str, k2.Fsa]:
     """This function extracts n paths from the given lattice and uses
     an attention decoder to rescore them. The path with the highest
     score is used as the decoding output.

-    lattice:
-        An FsaVec. It can be the return value of :func:`get_lattice`.
-    num_paths:
-        Number of paths to extract from the given lattice for rescoring.
-    model:
-        A transformer model. See the class "Transformer" in
-        conformer_ctc/transformer.py for its interface.
-    memory:
-        The encoder memory of the given model. It is the output of
-        the last torch.nn.TransformerEncoder layer in the given model.
-        Its shape is `[T, N, C]`.
-    memory_key_padding_mask:
-        The padding mask for memory with shape [N, T].
+    Args:
+      lattice:
+        An FsaVec. It can be the return value of :func:`get_lattice`.
+      num_paths:
+        Number of paths to extract from the given lattice for rescoring.
+      model:
+        A transformer model. See the class "Transformer" in
+        conformer_ctc/transformer.py for its interface.
+      memory:
+        The encoder memory of the given model. It is the output of
+        the last torch.nn.TransformerEncoder layer in the given model.
+        Its shape is `[T, N, C]`.
+      memory_key_padding_mask:
+        The padding mask for memory with shape [N, T].
+      sos_id:
+        The token ID for SOS.
+      eos_id:
+        The token ID for EOS.
     Returns:
       A dict of FsaVec, whose key contains a string
       ngram_lm_scale_attention_scale and the value is the
@@ -661,7 +668,11 @@ def rescore_with_attention_decoder(

     # TODO: pass the sos_token_id and eos_token_id via function arguments
     nll = model.decoder_nll(
-        expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
     )
     assert nll.ndim == 2
     assert nll.shape[0] == num_word_seqs
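A schematic view of how a caller now invokes the rescoring function (not standalone: lattice, model, memory, and the IDs come from the surrounding decode script, as in conformer_ctc/decode.py above):

best_path_dict = rescore_with_attention_decoder(
    lattice=lattice,
    num_paths=params.num_paths,
    model=model,
    memory=memory,  # [T, N, C], from the encoder
    memory_key_padding_mask=memory_key_padding_mask,  # [N, T]
    sos_id=sos_id,  # from BpeCtcTrainingGraphCompiler
    eos_id=eos_id,
)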
@@ -1,7 +1,8 @@
+import logging
 import re
 import sys
 from pathlib import Path
-from typing import List, Tuple, Union
+from typing import List, Tuple

 import k2
 import torch
@@ -31,13 +32,19 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
             continue

         if len(a) < 2:
-            print(f"Found bad line {line} in lexicon file {filename}")
-            print("Every line is expected to contain at least 2 fields")
+            logging.info(
+                f"Found bad line {line} in lexicon file {filename}"
+            )
+            logging.info(
+                "Every line is expected to contain at least 2 fields"
+            )
             sys.exit(1)
         word = a[0]
         if word == "<eps>":
-            print(f"Found bad line {line} in lexicon file {filename}")
-            print("<eps> should not be a valid word")
+            logging.info(
+                f"Found bad line {line} in lexicon file {filename}"
+            )
+            logging.info("<eps> should not be a valid word")
             sys.exit(1)

         tokens = a[1:]
@@ -61,13 +68,12 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:


 class Lexicon(object):
-    """Phone based lexicon.
-
-    TODO: Add BpeLexicon for BPE models.
-    """
+    """Phone based lexicon."""

     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Args:
@@ -121,7 +127,9 @@ class Lexicon(object):

 class BpeLexicon(Lexicon):
     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Refer to the help information in Lexicon.__init__.
@@ -225,7 +225,10 @@ def store_transcripts(


 def write_error_stats(
-    f: TextIO, test_set_name: str, results: List[Tuple[str, str]]
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, str]],
+    enable_log: bool = True,
 ) -> float:
     """Write statistics based on predicted results and reference transcripts.

@@ -255,6 +258,9 @@ def write_error_stats(
       results:
         An iterable of tuples. The first element is the reference transcript
         while the second element is the predicted result.
+      enable_log:
+        If True, also print detailed WER to the console.
+        Otherwise, it is written only to the given file.
     Returns:
       Return None.
     """
@@ -290,11 +296,12 @@ def write_error_stats(
     tot_errs = sub_errs + ins_errs + del_errs
     tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)

-    logging.info(
-        f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
-        f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
-        f"{del_errs} del, {sub_errs} sub ]"
-    )
+    if enable_log:
+        logging.info(
+            f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+            f"{del_errs} del, {sub_errs} sub ]"
+        )

     print(f"%WER = {tot_err_rate}", file=f)
     print(
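A minimal usage sketch of the new flag, assuming icefall is importable and that the transcripts are plain strings as the type hints suggest; with enable_log=False the per-test-set WER lines go only to the given file handle:

import sys

from icefall.utils import write_error_stats

results = [
    ("HELLO WORLD", "HELLO WORD"),  # (reference, hypothesis)
    ("FOO BAR", "FOO BAR"),
]
wer = write_error_stats(sys.stdout, "demo", results, enable_log=False)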