Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
Refactor decode.py to make it more readable and more modular. (#44)
* Refactor decode.py to make it more readable and more modular.
* Fix an error. Nbest.fsa should always have token IDs as labels and word IDs as aux_labels.
* Add nbest decoding.
* Compute edit distance with k2.
* Refactor nbest-oracle.
* Add rescore with nbest lists.
* Add whole-lattice rescoring.
* Add rescoring with attention decoder.
* Refactoring.
* Fixes after refactoring.
* Fix a typo.
* Minor fixes.
* Replace [] with () for shapes.
* Use k2 v1.9
* Use Levenshtein graphs/alignment from k2 v1.9
* [doc] Require k2 >= v1.9
* Minor fixes.
commit a80e58e15d (parent cc77cb3459)
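One bullet above is "Compute edit distance with k2", built on the Levenshtein graphs/alignments introduced in k2 v1.9. A minimal sketch of the idea, assuming the k2 v1.9 functions `k2.levenshtein_graph` and `k2.levenshtein_alignment`; the toy sequences and the score-to-distance rounding are illustrative, not taken from this commit:

```python
import torch
import k2

# Toy reference and hypothesis word-ID sequences.
refs = k2.levenshtein_graph([[1, 2, 3]])
hyps = k2.levenshtein_graph([[1, 3, 4]])

# Align hypothesis 0 against reference 0.
alignment = k2.levenshtein_alignment(
    refs=refs,
    hyps=hyps,
    hyp_to_ref_map=torch.tensor([0], dtype=torch.int32),
    sorted_match_ref=True,
)

# Each substitution/insertion/deletion contributes roughly -0.5 to the
# total score, so the edit distance can be recovered by rounding.
tot_scores = alignment.get_tot_scores(
    use_double_scores=True, log_semiring=False
)
print(round(-tot_scores[0].item() / 0.5))  # 2 edits for the toy pair
```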
.github/workflows/run-yesno-recipe.yml (vendored, 2 lines changed)
@@ -34,7 +34,7 @@ jobs:
         os: [ubuntu-18.04]
         python-version: [3.8]
         torch: ["1.8.1"]
-        k2-version: ["1.8.dev20210917"]
+        k2-version: ["1.9.dev20210919"]
       fail-fast: false

     steps:
.github/workflows/test.yml (vendored, 2 lines changed)
@@ -32,7 +32,7 @@ jobs:
         os: [ubuntu-18.04, macos-10.15]
        python-version: [3.6, 3.7, 3.8, 3.9]
        torch: ["1.8.1"]
-        k2-version: ["1.8.dev20210917"]
+        k2-version: ["1.9.dev20210919"]

      fail-fast: false
docs/source/installation/images/k2-v-1.7.svg (deleted file, 1.1 KiB)
@@ -1 +0,0 @@
-<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="80" height="20" role="img" aria-label="k2: >= v1.7"><title>k2: >= v1.7</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="80" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="57" height="20" fill="blueviolet"/><rect width="80" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">k2</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">k2</text><text aria-hidden="true" x="505" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="470">>= v1.7</text><text x="505" y="140" transform="scale(.1)" fill="#fff" textLength="470">>= v1.7</text></g></svg>
docs/source/installation/images/k2-v1.9-blueviolet.svg (new file, 1.1 KiB)
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="58" height="20" role="img" aria-label="k2: v1.9"><title>k2: v1.9</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="58" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="23" height="20" fill="#555"/><rect x="23" width="35" height="20" fill="blueviolet"/><rect width="58" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><text aria-hidden="true" x="125" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="130">k2</text><text x="125" y="140" transform="scale(.1)" fill="#fff" textLength="130">k2</text><text aria-hidden="true" x="395" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="250">v1.9</text><text x="395" y="140" transform="scale(.1)" fill="#fff" textLength="250">v1.9</text></g></svg>
@@ -21,7 +21,7 @@ Installation
 .. |torch_versions| image:: ./images/torch-1.6.0_1.7.0_1.7.1_1.8.0_1.8.1_1.9.0-green.svg
    :alt: Supported PyTorch versions

-.. |k2_versions| image:: ./images/k2-v-1.7.svg
+.. |k2_versions| image:: ./images/k2-v1.9-blueviolet.svg
    :alt: Supported k2 versions

 ``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
@@ -40,7 +40,7 @@ to install ``k2``.

 .. CAUTION::

-  You need to install ``k2`` with a version at least **v1.7**.
+  You need to install ``k2`` with a version at least **v1.9**.

 .. HINT::
@@ -98,7 +98,7 @@ class Conformer(Transformer):
         """
         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -213,12 +213,12 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)

     supervisions = batch["supervisions"]

     nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
-    # nnet_output is [N, T, C]
+    # nnet_output is (N, T, C)

     supervision_segments = torch.stack(
         (
@@ -244,14 +244,19 @@ def decode_one_batch(
         # Note: You can also pass rescored lattices to it.
         # We choose the HLG decoded lattice for speed reasons
         # as HLG decoding is faster and the oracle WER
-        # is slightly worse than that of rescored lattices.
-        return nbest_oracle(
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
             lattice=lattice,
             num_paths=params.num_paths,
             ref_texts=supervisions["text"],
             word_table=word_table,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
             oov="<UNK>",
         )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        return {key: hyps}

     if params.method in ["1best", "nbest"]:
         if params.method == "1best":
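The rename from `scale` to `lattice_score_scale` in this hunk makes the parameter's role easier to see: it scales the lattice scores before n-best paths are sampled, so smaller values flatten the score distribution and yield more diverse paths. A rough sketch of that mechanism; the helper below is hypothetical, and the sampling internals are an assumption not shown in this diff:

```python
import k2


def sample_nbest_paths(
    lattice: k2.Fsa, num_paths: int, lattice_score_scale: float
) -> k2.RaggedTensor:
    # Hypothetical helper: scale the scores, sample, then restore them.
    saved_scores = lattice.scores.clone()
    lattice.scores = lattice.scores * lattice_score_scale
    # k2.random_paths returns a ragged tensor of arc indices with
    # axes [utt][path][arc]; smaller scales give more diverse paths.
    path = k2.random_paths(
        lattice, use_double_scores=True, num_paths=num_paths
    )
    lattice.scores = saved_scores
    return path
```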
@@ -264,7 +269,7 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
         )
         key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
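For the `nbest` branch above, the overall flow through the new `Nbest` class can be pieced together from `test/test_decode.py` at the bottom of this commit; a runnable sketch on a toy lattice:

```python
import k2
from icefall.decode import Nbest

# A toy lattice in the same style as the test file below.
s = """
    0 1 1 10 0.1
    0 1 2 20 0.2
    1 2 3 30 0.3
    1 2 4 40 0.4
    2 3 -1 -1 0.5
    3
"""
lattice = k2.Fsa.from_str(s, acceptor=False)
lattice = k2.Fsa.from_fsas([lattice])

nbest = Nbest.from_lattice(
    lattice=lattice,
    num_paths=10,
    use_double_scores=True,
    lattice_score_scale=0.5,
)
nbest = nbest.intersect(lattice)  # attach lattice scores to each path
tot_scores = nbest.tot_scores()   # one total score per unique path
best_path = k2.index_fsa(nbest.fsa, tot_scores.argmax())
print(best_path[0])
```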
@@ -288,17 +293,23 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
         )
     elif params.method == "attention-decoder":
         # lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
         rescored_lattice = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=None,
        )
         # TODO: pass `lattice` instead of `rescored_lattice` to
         # `rescore_with_attention_decoder`

         best_path_dict = rescore_with_attention_decoder(
             lattice=rescored_lattice,
@@ -308,16 +319,20 @@ def decode_one_batch(
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
             eos_id=eos_id,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"

     ans = dict()
-    for lm_scale_str, best_path in best_path_dict.items():
-        hyps = get_texts(best_path)
-        hyps = [[word_table[i] for i in ids] for ids in hyps]
-        ans[lm_scale_str] = hyps
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        for lm_scale in lm_scale_list:
+            ans[str(lm_scale)] = [[] for _ in range(lattice.shape[0])]
     return ans
@@ -336,7 +336,7 @@ def main():
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=params.sos_id,
             eos_id=params.eos_id,
-            scale=params.lattice_score_scale,
+            lattice_score_scale=params.lattice_score_scale,
             ngram_lm_scale=params.ngram_lm_scale,
             attention_scale=params.attention_decoder_scale,
         )
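`ngram_lm_scale` and `attention_scale` weight the extra scores that attention-decoder rescoring adds on top of the acoustic score. A hypothetical illustration of the combination (the exact formula inside `rescore_with_attention_decoder` is not shown in this diff):

```python
def combine_scores(
    am_score: float,
    ngram_lm_score: float,
    attention_score: float,
    ngram_lm_scale: float,
    attention_scale: float,
) -> float:
    # One candidate path's total score under a given pair of scales;
    # decoding keeps the path (and scale pair) with the best total.
    return (
        am_score
        + ngram_lm_scale * ngram_lm_score
        + attention_scale * attention_score
    )


print(combine_scores(-10.0, -3.0, -2.5, 0.5, 0.5))  # -12.75
```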
@@ -22,8 +22,8 @@ import torch.nn as nn
 class Conv2dSubsampling(nn.Module):
     """Convolutional 2D subsampling (to 1/4 length).

-    Convert an input of shape [N, T, idim] to an output
-    with shape [N, T', odim], where
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
     T' = ((T-1)//2 - 1)//2, which approximates T' == T//4

     It is based on
@@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module):
         """
         Args:
           idim:
-            Input dim. The input shape is [N, T, idim].
+            Input dim. The input shape is (N, T, idim).
             Caution: It requires: T >= 7, idim >= 7
           odim:
-            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
         """
         assert idim >= 7
         super().__init__()
@@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module):

         Args:
           x:
-            Its shape is [N, T, idim].
+            Its shape is (N, T, idim).

         Returns:
-          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
         """
-        # On entry, x is [N, T, idim]
-        x = x.unsqueeze(1)  # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
+        # On entry, x is (N, T, idim)
+        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim), i.e., (N, C, H, W)
         x = self.conv(x)
-        # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
+        # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
-        # Now x is of shape [N, ((T-1)//2 - 1))//2, odim]
+        # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
         return x
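The subsampling formula in the docstrings above is easy to sanity-check; a short script confirming that T' = ((T-1)//2 - 1)//2 stays close to T//4:

```python
# Check the Conv2dSubsampling output-length formula against T // 4.
for T in [7, 100, 163, 1000]:
    T_prime = ((T - 1) // 2 - 1) // 2
    print(f"T={T:4d}  T'={T_prime:3d}  T//4={T // 4:3d}")
# T' differs from T // 4 by at most 1 for these lengths.
```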
@@ -80,8 +80,8 @@ class VggSubsampling(nn.Module):
     This paper is not 100% explicit so I am guessing to some extent,
     and trying to compare with other VGG implementations.

-    Convert an input of shape [N, T, idim] to an output
-    with shape [N, T', odim], where
+    Convert an input of shape (N, T, idim) to an output
+    with shape (N, T', odim), where
     T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
     """
@@ -93,10 +93,10 @@ class VggSubsampling(nn.Module):

         Args:
           idim:
-            Input dim. The input shape is [N, T, idim].
+            Input dim. The input shape is (N, T, idim).
             Caution: It requires: T >= 7, idim >= 7
           odim:
-            Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
+            Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim)
         """
         super().__init__()
@@ -149,10 +149,10 @@ class VggSubsampling(nn.Module):

         Args:
           x:
-            Its shape is [N, T, idim].
+            Its shape is (N, T, idim).

         Returns:
-          Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
+          Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim)
         """
         x = x.unsqueeze(1)
         x = self.layers(x)
@@ -310,14 +310,14 @@ def compute_loss(
     """
     device = graph_compiler.device
     feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)

     supervisions = batch["supervisions"]
     with torch.set_grad_enabled(is_training):
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is [N, T, C]
+        # nnet_output is (N, T, C)

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
@@ -83,8 +83,8 @@ class Transformer(nn.Module):
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")

-        # self.encoder_embed converts the input of shape [N, T, num_classes]
-        # to the shape [N, T//subsampling_factor, d_model].
+        # self.encoder_embed converts the input of shape (N, T, num_classes)
+        # to the shape (N, T//subsampling_factor, d_model).
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_classes -> d_model
@@ -162,7 +162,7 @@ class Transformer(nn.Module):
         """
         Args:
           x:
-            The input tensor. Its shape is [N, T, C].
+            The input tensor. Its shape is (N, T, C).
           supervision:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -171,17 +171,17 @@ class Transformer(nn.Module):

         Returns:
           Return a tuple containing 3 tensors:
-            - CTC output for ctc decoding. Its shape is [N, T, C]
-            - Encoder output with shape [T, N, C]. It can be used as key and
+            - CTC output for ctc decoding. Its shape is (N, T, C)
+            - Encoder output with shape (T, N, C). It can be used as key and
               value for the decoder.
             - Encoder output padding mask. It can be used as
-              memory_key_padding_mask for the decoder. Its shape is [N, T].
+              memory_key_padding_mask for the decoder. Its shape is (N, T).
               It is None if `supervision` is None.
         """
         if self.use_feat_batchnorm:
-            x = x.permute(0, 2, 1)  # [N, T, C] -> [N, C, T]
+            x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
             x = self.feat_batchnorm(x)
-            x = x.permute(0, 2, 1)  # [N, C, T] -> [N, T, C]
+            x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
         encoder_memory, memory_key_padding_mask = self.run_encoder(
             x, supervision
         )
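The permute pair around `self.feat_batchnorm` exists because `torch.nn.BatchNorm1d` expects channels in dimension 1, i.e. a (N, C, T) tensor, while the model keeps features as (N, T, C). A self-contained sketch of the same pattern with example sizes:

```python
import torch
import torch.nn as nn

N, T, C = 4, 100, 80  # example sizes
x = torch.randn(N, T, C)

bn = nn.BatchNorm1d(C)  # normalizes over dim 1 of a (N, C, T) input
y = bn(x.permute(0, 2, 1)).permute(0, 2, 1)  # (N,T,C)->(N,C,T)->(N,T,C)
assert y.shape == (N, T, C)
```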
@@ -195,7 +195,7 @@ class Transformer(nn.Module):

         Args:
           x:
-            The model input. Its shape is [N, T, C].
+            The model input. Its shape is (N, T, C).
           supervisions:
             Supervision in lhotse format.
             See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
@@ -206,8 +206,8 @@ class Transformer(nn.Module):
             padding mask for the decoder.
         Returns:
           Return a tuple with two tensors:
-            - The encoder output, with shape [T, N, C]
-            - encoder padding mask, with shape [N, T].
+            - The encoder output, with shape (T, N, C)
+            - encoder padding mask, with shape (N, T).
              The mask is None if `supervisions` is None.
              It is used as memory key padding mask in the decoder.
         """
@@ -225,11 +225,11 @@ class Transformer(nn.Module):
         Args:
           x:
             The output tensor from the transformer encoder.
-            Its shape is [T, N, C]
+            Its shape is (T, N, C)

         Returns:
           Return a tensor that can be used for CTC decoding.
-          Its shape is [N, T, C]
+          Its shape is (N, T, C)
         """
         x = self.encoder_output_layer(x)
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
@@ -247,7 +247,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -312,7 +312,7 @@ class Transformer(nn.Module):
         """
         Args:
           memory:
-            It's the output of the encoder with shape [T, N, C]
+            It's the output of the encoder with shape (T, N, C)
           memory_key_padding_mask:
             The padding mask from the encoder.
           token_ids:
@@ -654,13 +654,13 @@ class PositionalEncoding(nn.Module):
     def extend_pe(self, x: torch.Tensor) -> None:
         """Extend the time t in the positional encoding if required.

-        The shape of `self.pe` is [1, T1, d_model]. The shape of the input x
-        is [N, T, d_model]. If T > T1, then we change the shape of self.pe
-        to [N, T, d_model]. Otherwise, nothing is done.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (N, T, d_model). Otherwise, nothing is done.

         Args:
           x:
-            It is a tensor of shape [N, T, C].
+            It is a tensor of shape (N, T, C).
         Returns:
           Return None.
         """
@@ -678,7 +678,7 @@ class PositionalEncoding(nn.Module):
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
         pe = pe.unsqueeze(0)
-        # Now pe is of shape [1, T, d_model], where T is x.size(1)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
         self.pe = pe.to(device=x.device, dtype=x.dtype)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -687,10 +687,10 @@ class PositionalEncoding(nn.Module):

         Args:
           x:
-            Its shape is [N, T, C]
+            Its shape is (N, T, C)

         Returns:
-          Return a tensor of shape [N, T, C]
+          Return a tensor of shape (N, T, C)
         """
         self.extend_pe(x)
         x = x * self.xscale + self.pe[:, : x.size(1), :]
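For context, the sine/cosine table that `extend_pe` maintains follows the standard transformer formulation; a self-contained sketch with example sizes (the `div_term` expression is the usual one and is assumed here, since the hunks above show only the `sin`/`cos` assignments and the final addition):

```python
import math

import torch

T, d_model = 50, 16  # example sizes
position = torch.arange(T, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2, dtype=torch.float32)
    * -(math.log(10000.0) / d_model)
)
pe = torch.zeros(T, d_model)
pe[:, 0::2] = torch.sin(position * div_term)  # even columns
pe[:, 1::2] = torch.cos(position * div_term)  # odd columns
pe = pe.unsqueeze(0)  # (1, T, d_model), ready to be added to (N, T, C)
```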
@@ -190,12 +190,12 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)

-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
+    feature = feature.permute(0, 2, 1)  # now feature is (N, C, T)

     nnet_output = model(feature)
-    # nnet_output is [N, T, C]
+    # nnet_output is (N, T, C)

     supervisions = batch["supervisions"]
@@ -229,6 +229,7 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
+            lattice_score_scale=params.lattice_score_scale,
         )
         key = f"no_rescore-{params.num_paths}"
         hyps = get_texts(best_path)
@@ -247,10 +248,13 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
+            lattice_score_scale=params.lattice_score_scale,
         )
     else:
         best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
         )

     ans = dict()
@@ -218,11 +218,11 @@ def main():
     features = pad_sequence(
         features, batch_first=True, padding_value=math.log(1e-10)
     )
-    features = features.permute(0, 2, 1)  # now features is [N, C, T]
+    features = features.permute(0, 2, 1)  # now features is (N, C, T)

     with torch.no_grad():
         nnet_output = model(features)
-        # nnet_output is [N, T, C]
+        # nnet_output is (N, T, C)

     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
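`padding_value=math.log(1e-10)` pads shorter utterances with a very small log-mel energy rather than zero, which in the log domain would correspond to a fairly loud frame. A self-contained sketch of the padding step with toy feature sizes:

```python
import math

import torch
from torch.nn.utils.rnn import pad_sequence

feats = [torch.zeros(5, 80), torch.zeros(3, 80)]  # two utterances
padded = pad_sequence(
    feats, batch_first=True, padding_value=math.log(1e-10)
)
print(padded.shape)      # torch.Size([2, 5, 80])
print(padded[1, 3:, 0])  # padding frames filled with log(1e-10)
```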
@@ -290,14 +290,14 @@ def compute_loss(
     """
     device = graph_compiler.device
     feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
-    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
+    # at entry, feature is (N, T, C)
+    feature = feature.permute(0, 2, 1)  # now feature is (N, C, T)
     assert feature.ndim == 3
     feature = feature.to(device)

     with torch.set_grad_enabled(is_training):
         nnet_output = model(feature)
-        # nnet_output is [N, T, C]
+        # nnet_output is (N, T, C)

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
@@ -111,10 +111,10 @@ def decode_one_batch(
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)

     nnet_output = model(feature)
-    # nnet_output is [N, T, C]
+    # nnet_output is (N, T, C)

     batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
@@ -268,13 +268,13 @@ def compute_loss(
     """
     device = graph_compiler.device
     feature = batch["inputs"]
-    # at entry, feature is [N, T, C]
+    # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)

     with torch.set_grad_enabled(is_training):
         nnet_output = model(feature)
-        # nnet_output is [N, T, C]
+        # nnet_output is (N, T, C)

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
icefall/decode.py (1065 lines changed; diff suppressed because it is too large)
@@ -106,7 +106,7 @@ class CtcTrainingGraphCompiler(object):
         word_ids_list = []
         for text in texts:
             word_ids = []
-            for word in text.split(" "):
+            for word in text.split():
                 if word in self.word_table:
                     word_ids.append(self.word_table[word])
                 else:
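The switch from `text.split(" ")` to `text.split()` is a real fix, not just style: the no-argument form splits on any whitespace run and drops empty strings, so repeated or leading/trailing spaces no longer produce spurious empty "words":

```python
text = " hello  world "
print(text.split(" "))  # ['', 'hello', '', 'world', '']
print(text.split())     # ['hello', 'world']
```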
@@ -186,7 +186,9 @@ def encode_supervisions(
     return supervision_segments, texts


-def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
+def get_texts(
+    best_paths: k2.Fsa, return_ragged: bool = False
+) -> Union[List[List[int]], k2.RaggedTensor]:
     """Extract the texts (as word IDs) from the best-path FSAs.
     Args:
       best_paths:
@@ -194,6 +196,9 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
           containing multiple FSAs, which is expected to be the result
           of k2.shortest_path (otherwise the returned values won't
           be meaningful).
+      return_ragged:
+        True to return a ragged tensor with two axes [utt][word_id].
+        False to return a list of lists of word IDs.
     Returns:
       Returns a list of lists of int, containing the label sequences we
       decoded.
@@ -216,7 +221,10 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
         aux_labels = aux_labels.remove_values_leq(0)

     assert aux_labels.num_axes == 2
-    return aux_labels.tolist()
+    if return_ragged:
+        return aux_labels
+    else:
+        return aux_labels.tolist()


 def store_transcripts(
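A short sketch of the two return modes added to `get_texts`, using a toy best path (this assumes `get_texts` lives in `icefall.utils`, which the diff does not state):

```python
import k2
from icefall.utils import get_texts  # module path assumed

s = """
    0 1 1 10 0.1
    1 2 2 20 0.2
    2 3 -1 -1 0.3
    3
"""
best_paths = k2.Fsa.from_fsas([k2.Fsa.from_str(s, acceptor=False)])

word_ids = get_texts(best_paths)                    # List[List[int]]
ragged = get_texts(best_paths, return_ragged=True)  # k2.RaggedTensor
assert ragged.tolist() == word_ids                  # same content
print(word_ids)  # [[10, 20]]
```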
test/test_decode.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+You can run this file in one of two ways:
+
+    (1) cd icefall; pytest test/test_decode.py
+    (2) cd icefall; ./test/test_decode.py
+"""
+
+import k2
+
+from icefall.decode import Nbest
+
+
+def test_nbest_from_lattice():
+    s = """
+        0 1 1 10 0.1
+        0 1 5 10 0.11
+        0 1 2 20 0.2
+        1 2 3 30 0.3
+        1 2 4 40 0.4
+        2 3 -1 -1 0.5
+        3
+    """
+    lattice = k2.Fsa.from_str(s, acceptor=False)
+    lattice = k2.Fsa.from_fsas([lattice, lattice])
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=10,
+        use_double_scores=True,
+        lattice_score_scale=0.5,
+    )
+    # Each lattice has only 4 distinct paths that have different word
+    # sequences:
+    #   10->30
+    #   10->40
+    #   20->30
+    #   20->40
+    #
+    # so there should be only 4 paths for each lattice in the Nbest object.
+    assert nbest.fsa.shape[0] == 4 * 2
+    assert nbest.shape.row_splits(1).tolist() == [0, 4, 8]
+
+    nbest2 = nbest.intersect(lattice)
+    tot_scores = nbest2.tot_scores()
+    argmax = tot_scores.argmax()
+    best_path = k2.index_fsa(nbest2.fsa, argmax)
+    print(best_path[0])