Tedlium3 conformer ctc2 (#696)

* modify preparation

* small refactor

* add tedlium3 conformer_ctc2

* modify decode

* filter unk in decode

* add scaling converter

* address comments

* fix lambda function lhotse

* add implicit manifest shuffle

* refactor ctc_greedy_search

* import model arguments from train.py

* style fix

* fix ci test and last style issues

* update RESULTS

* fix RESULTS numbers

* fix label smoothing loss

* update model parameters number in RESULTS
Daniil 2022-12-13 03:13:26 -05:00 committed by GitHub
parent 0470bbae66
commit b293db4baf
35 changed files with 5158 additions and 215 deletions

View File

@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module):
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0
assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
assert reduction in ("none", "sum", "mean"), reduction
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
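For context, a minimal standalone sketch of what label smoothing does to the target distribution (a generic formulation, not the icefall implementation): the true class keeps `1 - label_smoothing` of the probability mass and the remainder is spread uniformly over all classes.

```
import torch

def smoothed_targets(labels: torch.Tensor, num_classes: int, label_smoothing: float) -> torch.Tensor:
    # one-hot targets softened by the smoothing factor
    one_hot = torch.nn.functional.one_hot(labels, num_classes).float()
    return one_hot * (1.0 - label_smoothing) + label_smoothing / num_classes

targets = smoothed_targets(torch.tensor([2, 0, 1]), num_classes=4, label_smoothing=0.1)
assert torch.allclose(targets.sum(dim=-1), torch.ones(3))  # rows still sum to one
```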

View File

@ -24,10 +24,9 @@ from scaling import (
ScaledConv2d,
ScaledLinear,
)
from torch import nn
class Conv2dSubsampling(nn.Module):
class Conv2dSubsampling(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
Convert an input of shape (N, T, idim) to an output
@ -61,7 +60,7 @@ class Conv2dSubsampling(nn.Module):
assert in_channels >= 7
super().__init__()
self.conv = nn.Sequential(
self.conv = torch.nn.Sequential(
ScaledConv2d(
in_channels=1,
out_channels=layer1_channels,

View File

@ -1435,7 +1435,7 @@ class EmformerEncoder(nn.Module):
self,
x: torch.Tensor,
states: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor],]:
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward pass for streaming inference.
B: batch size;
@ -1640,7 +1640,7 @@ class Emformer(EncoderInterface):
self,
x: torch.Tensor,
states: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor],]:
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward pass for streaming inference.
B: batch size;

View File

@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from
Caution: We use a lexicon that contains disambiguation symbols
- G, the LM, built from data/lm/G_3_gram.fst.txt
- G, the LM, built from data/lm/G_n_gram.fst.txt
The generated HLG is saved in $lang_dir/HLG.pt
"""

View File

@ -28,7 +28,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@ -41,6 +41,10 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def is_cut_long(c: MonoCut) -> bool:
return c.duration > 5
def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
@ -86,7 +90,7 @@ def compute_fbank_musan():
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
.filter(is_cut_long)
.compute_and_store_features(
extractor=extractor,
storage_path=f"{output_dir}/musan_feats",
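The lambda-to-named-function change above (see the commit message "fix lambda function lhotse") is presumably needed because the filter may run in worker processes and lambdas cannot be pickled. A minimal illustration with plain pickle, independent of lhotse; `is_cut_long` here mirrors the module-level helper added above:

```
import pickle

def is_cut_long(c) -> bool:  # module-level function: picklable by reference
    return c.duration > 5

pickle.dumps(is_cut_long)  # works

try:
    pickle.dumps(lambda c: c.duration > 5)
except (pickle.PicklingError, AttributeError) as e:
    print("lambda is not picklable:", e)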

View File

@ -127,7 +127,7 @@ def lexicon_to_fst_no_sil(
def generate_lexicon(
model_file: str, words: List[str]
model_file: str, words: List[str], oov: str
) -> Tuple[Lexicon, Dict[str, int]]:
"""Generate a lexicon from a BPE model.
@ -136,6 +136,8 @@ def generate_lexicon(
Path to a sentencepiece model.
words:
A list of strings representing words.
oov:
The out of vocabulary word in lexicon.
Returns:
Return a tuple with two elements:
- A dict whose keys are words and values are the corresponding
@ -156,12 +158,9 @@ def generate_lexicon(
for word, pieces in zip(words, words_pieces):
lexicon.append((word, pieces))
# The OOV word is <UNK>
lexicon.append(("<UNK>", [sp.id_to_piece(sp.unk_id())]))
lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
token2id: Dict[str, int] = dict()
for i in range(sp.vocab_size()):
token2id[sp.id_to_piece(i)] = i
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
return lexicon, token2id
@ -176,6 +175,13 @@ def get_args():
""",
)
parser.add_argument(
"--oov",
type=str,
default="<UNK>",
help="The out of vocabulary word in lexicon.",
)
parser.add_argument(
"--debug",
type=str2bool,
@ -202,12 +208,13 @@ def main():
words = word_sym_table.symbols
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
for w in excluded:
if w in words:
words.remove(w)
lexicon, token_sym_table = generate_lexicon(model_file, words)
lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
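Illustrative only: with the new `--oov` option, the entry that `generate_lexicon(model_file, words, oov)` appends for the OOV word is expected to look as follows. No real BPE model is loaded here; `<unk>` stands in for `sp.id_to_piece(sp.unk_id())`.

```
oov = "<UNK>"
unk_piece = "<unk>"                       # sp.id_to_piece(sp.unk_id()) for a typical BPE model
oov_entry = (oov, ["\u2581", unk_piece])  # "▁" is the sentencepiece word-boundary marker
print(oov_entry[0], " ".join(oov_entry[1]))  # -> <UNK> ▁ <unk>
```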

View File

@ -652,16 +652,16 @@ class ActivationBalancer(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
if random.random() >= self.balance_prob:
return x
else:
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor / self.balance_prob,
self.min_abs,
self.max_abs,
)
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor / self.balance_prob,
self.min_abs,
self.max_abs,
)
class DoubleSwishFunction(torch.autograd.Function):

View File

@ -282,7 +282,7 @@ def convert_scaled_to_non_scaled(
if not inplace:
model = copy.deepcopy(model)
excluded_patterns = r"self_attn\.(in|out)_proj"
excluded_patterns = r"(self|src)_attn\.(in|out)_proj"
p = re.compile(excluded_patterns)
d = {}
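A quick illustrative check of what the widened pattern excludes: both the self-attention and the source-attention projections now match (the module names below are made up).

```
import re

p = re.compile(r"(self|src)_attn\.(in|out)_proj")
names = [
    "encoder.layers.0.self_attn.in_proj",
    "decoder.layers.3.src_attn.out_proj",   # matched only by the new pattern
    "encoder.layers.0.feed_forward.0",
]
print([n for n in names if p.search(n)])
# ['encoder.layers.0.self_attn.in_proj', 'decoder.layers.3.src_attn.out_proj']
```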

View File

@ -1,5 +1,88 @@
## Results
### TedLium3 BPE training results (Conformer-CTC 2)
#### [conformer_ctc2](./conformer_ctc2)
See <https://github.com/k2-fsa/icefall/pull/696> for more details.
The tensorboard log can be found at
<https://tensorboard.dev/experiment/5NQQiqOqSqazfn4w2yeWEQ/>
You can find a pretrained model and decoding results at:
<https://huggingface.co/videodanchik/icefall-asr-tedlium3-conformer-ctc2>
Number of model parameters: 101141699, i.e., 101.14 M
The WERs are
| decoding method | dev | test | comment |
|--------------------------|------------|-------------|---------------------|
| ctc decoding | 6.45 | 5.96 | --epoch 38 --avg 26 |
| 1best | 5.92 | 5.51 | --epoch 38 --avg 26 |
| whole lattice rescoring | 5.96 | 5.47 | --epoch 38 --avg 26 |
| attention decoder | 5.60 | 5.33 | --epoch 38 --avg 26 |
The training command for reproducing is given below:
```
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./conformer_ctc2/train.py \
--world-size 4 \
--num-epochs 40 \
--exp-dir conformer_ctc2/exp \
--max-duration 350 \
--use-fp16 true
```
The decoding command is:
```
epoch=38
avg=26
## ctc decoding
./conformer_ctc2/decode.py \
--method ctc-decoding \
--exp-dir conformer_ctc2/exp \
--lang-dir data/lang_bpe_500 \
--result-dir conformer_ctc2/exp \
--max-duration 500 \
--epoch $epoch \
--avg $avg
## 1best
./conformer_ctc2/decode.py \
--method 1best \
--exp-dir conformer_ctc2/exp \
--lang-dir data/lang_bpe_500 \
--result-dir conformer_ctc2/exp \
--max-duration 500 \
--epoch $epoch \
--avg $avg
## whole lattice rescoring
./conformer_ctc2/decode.py \
--method whole-lattice-rescoring \
--exp-dir conformer_ctc2/exp \
--lm-path data/lm/G_4_gram_big.pt \
--lang-dir data/lang_bpe_500 \
--result-dir conformer_ctc2/exp \
--max-duration 500 \
--epoch $epoch \
--avg $avg
## attention decoder
./conformer_ctc2/decode.py \
--method attention-decoder \
--exp-dir conformer_ctc2/exp \
--lang-dir data/lang_bpe_500 \
--result-dir conformer_ctc2/exp \
--max-duration 500 \
--epoch $epoch \
--avg $avg
```
### TedLium3 BPE training results (Pruned Transducer)
#### 2022-03-21

View File

View File

@ -0,0 +1 @@
../transducer_stateless/asr_datamodule.py

View File

@ -0,0 +1,201 @@
# Copyright 2022 Behavox LLC. (author: Daniil Kulko)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
from scaling import ScaledLinear
class MultiheadAttention(torch.nn.Module):
"""Allows the model to jointly attend to information
from different representation subspaces. This is a modified
version of the original multihead attention
(see Attention Is All You Need <https://arxiv.org/abs/1706.03762>)
in which the input / output projection layers are replaced
with the newly introduced ScaledLinear layer
(see https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py).
Args:
embed_dim:
total dimension of the model.
num_heads:
number of parallel attention heads. Note that embed_dim will be split
across num_heads, i.e. each head will have dimension (embed_dim // num_heads).
dropout:
dropout probability on attn_output_weights. (default=0.0).
bias:
if specified, adds bias to input / output projection layers (default=True).
add_bias_kv:
if specified, adds bias to the key and value sequences at dim=0 (default=False).
add_zero_attn:
if specified, adds a new batch of zeros to the key and value sequences
at dim=1 (default=False).
batch_first:
if True, then the input and output tensors are provided as
(batch, seq, feature), otherwise (seq, batch, feature) (default=False).
Examples::
>>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
add_bias_kv: bool = False,
add_zero_attn: bool = False,
batch_first: bool = False,
device: Union[torch.device, str, None] = None,
dtype: Union[torch.dtype, str, None] = None,
) -> None:
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim must be divisible by num_heads. "
"Got embedding dim vs number 0f heads: "
f"{embed_dim} vs {num_heads}"
)
self.head_dim = embed_dim // num_heads
self.in_proj = ScaledLinear(
embed_dim,
3 * embed_dim,
bias=bias,
device=device,
dtype=dtype,
)
self.out_proj = ScaledLinear(
embed_dim,
embed_dim,
bias=bias,
initial_scale=0.25,
device=device,
dtype=dtype,
)
if add_bias_kv:
self.bias_k = torch.nn.Parameter(
torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
)
self.bias_v = torch.nn.Parameter(
torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
)
else:
self.register_parameter("bias_k", None)
self.register_parameter("bias_v", None)
self.add_zero_attn = add_zero_attn
self._reset_parameters()
def _reset_parameters(self) -> None:
if self.bias_k is not None:
torch.nn.init.xavier_normal_(self.bias_k)
if self.bias_v is not None:
torch.nn.init.xavier_normal_(self.bias_v)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
query:
Query embeddings of shape (L, N, E_q) when batch_first=False or (N, L, E_q)
when batch_first=True, where L is the target sequence length, N is the batch size,
and E_q is the query embedding dimension embed_dim. Queries are compared against
key-value pairs to produce the output. See "Attention Is All You Need" for more details.
key:
Key embeddings of shape (S, N, E_k) when batch_first=False or (N, S, E_k) when
batch_first=True, where S is the source sequence length, N is the batch size, and
E_k is the key embedding dimension kdim. See "Attention Is All You Need" for more details.
value:
Value embeddings of shape (S, N, E_v) when batch_first=False or (N, S, E_v) when
batch_first=True, where S is the source sequence length, N is the batch size, and
E_v is the value embedding dimension vdim. See "Attention Is All You Need" for more details.
key_padding_mask:
If specified, a mask of shape (N, S) indicating which elements within key
to ignore for the purpose of attention (i.e. treat as "padding").
Binary and byte masks are supported. For a binary mask, a True value indicates
that the corresponding key value will be ignored for the purpose of attention.
For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.
need_weights:
If specified, returns attn_output_weights in addition to attn_outputs (default=True).
attn_mask:
If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
(L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length,
and S is the source sequence length. A 2D mask will be broadcasted across the batch while
a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a True value indicates
that the corresponding position is not allowed to attend. For a byte mask, a non-zero
value indicates that the corresponding position is not allowed to attend. For a float mask,
the mask values will be added to the attention weight.
Returns:
attn_output:
Attention outputs of shape (L, N, E) when batch_first=False or (N, L, E) when batch_first=True,
where L is the target sequence length, N is the batch size, and E is the embedding dimension
embed_dim.
attn_output_weights:
Attention output weights of shape (N, L, S), where N is the batch size, L is the target sequence
length, and S is the source sequence length. Only returned when need_weights=True.
"""
if self.batch_first:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
(
attn_output,
attn_output_weights,
) = torch.nn.functional.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
in_proj_weight=self.in_proj.get_weight(),
in_proj_bias=self.in_proj.get_bias(),
bias_k=self.bias_k,
bias_v=self.bias_v,
add_zero_attn=self.add_zero_attn,
dropout_p=self.dropout,
out_proj_weight=self.out_proj.get_weight(),
out_proj_bias=self.out_proj.get_bias(),
training=self.training,
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
)
if self.batch_first:
return attn_output.transpose(1, 0), attn_output_weights
return attn_output, attn_output_weights
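A usage sketch for the class above, following the shapes documented in its docstring; it assumes the conformer_ctc2 directory is on sys.path so that attention.py and scaling.py can be imported.

```
import torch
from attention import MultiheadAttention  # the class defined above

embed_dim, num_heads = 512, 8
mha = MultiheadAttention(embed_dim, num_heads, dropout=0.1)

tgt_len, src_len, batch = 20, 30, 4
query = torch.randn(tgt_len, batch, embed_dim)
key = torch.randn(src_len, batch, embed_dim)
value = torch.randn(src_len, batch, embed_dim)

attn_output, attn_weights = mha(query, key, value)
print(attn_output.shape)   # torch.Size([20, 4, 512])
print(attn_weights.shape)  # torch.Size([4, 20, 30])
```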

View File

@ -0,0 +1,244 @@
# Copyright 2022 Behavox LLC. (author: Daniil Kulko)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
class RandomCombine(torch.nn.Module):
"""
This module combines a list of Tensors, all with the same shape, to
produce a single output of that same shape which, at training time,
is a random combination of all the inputs; but which at test time
will be just the last input.
The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS).
"""
def __init__(
self,
num_inputs: int,
final_weight: float = 0.5,
pure_prob: float = 0.5,
stddev: float = 2.0,
) -> None:
"""
Args:
num_inputs:
The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
final_weight:
The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing
continuous layer weights.
pure_prob:
The probability, on each frame, with which we choose
only a single layer to output (rather than an interpolation)
stddev:
A standard deviation that we add to log-probs for computing
randomized weights.
The method of choosing which layers, or combinations of layers, to use,
is conceptually as follows::
With probability `pure_prob`::
With probability `final_weight`: choose final layer,
Else: choose random non-final layer.
Else::
Choose initial log-weights that correspond to assigning
weight `final_weight` to the final layer and equal
weights to other layers; then add Gaussian noise
with variance `stddev` to these log-weights, and normalize
to weights (note: the average weight assigned to the
final layer here will not be `final_weight` if stddev>0).
"""
super().__init__()
assert 0 <= pure_prob <= 1, pure_prob
assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1, num_inputs
self.num_inputs = num_inputs
self.final_weight = final_weight
self.pure_prob = pure_prob
self.stddev = stddev
self.final_log_weight = (
torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1))
.log()
.item()
)
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
"""Forward function.
Args:
inputs:
A list of Tensor, e.g. from various layers of a transformer.
All must be the same shape, of (*, num_channels)
Returns:
A Tensor of shape (*, num_channels). In test mode
this is just the final input.
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs, f"{len(inputs)}, {num_inputs}"
if not self.training or torch.jit.is_scripting() or len(inputs) == 1:
return inputs[-1]
# Shape of weights: (*, num_inputs)
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs)
)
# weights: (num_frames, num_inputs)
weights = self._get_random_weights(
inputs[0].dtype, inputs[0].device, num_frames
)
weights = weights.reshape(num_frames, num_inputs, 1)
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
return ans
def _get_random_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired
Returns:
A tensor of shape (num_frames, self.num_inputs), such that
`ans.sum(dim=1)` is all ones.
"""
pure_prob = self.pure_prob
if pure_prob == 0.0:
return self._get_random_mixed_weights(dtype, device, num_frames)
elif pure_prob == 1.0:
return self._get_random_pure_weights(dtype, device, num_frames)
else:
p = self._get_random_pure_weights(dtype, device, num_frames)
m = self._get_random_mixed_weights(dtype, device, num_frames)
return torch.where(
torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
)
def _get_random_pure_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
exactly one weight equal to 1.0 on each frame.
"""
final_prob = self.final_weight
# final contains self.num_inputs - 1 in all elements
final = torch.full((num_frames,), self.num_inputs - 1, device=device)
# nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
indexes = torch.where(
torch.rand(num_frames, device=device) < final_prob, final, nonfinal
)
ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(
dtype=dtype
)
return ans
def _get_random_mixed_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A tensor of shape (num_frames, self.num_inputs), whose elements
are in [0..1] and sum to one over the second axis, i.e.
`ans.sum(dim=1)` is all ones.
"""
logprobs = (
torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
* self.stddev
)
logprobs[:, -1] += self.final_log_weight
return logprobs.softmax(dim=1)
def _test_random_combine(
final_weight: float,
pure_prob: float,
stddev: float,
) -> None:
print(
f"_test_random_combine: final_weight={final_weight}, "
f"pure_prob={pure_prob}, stddev={stddev}"
)
num_inputs = 3
num_channels = 50
m = RandomCombine(
num_inputs=num_inputs,
final_weight=final_weight,
pure_prob=pure_prob,
stddev=stddev,
)
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
y = m(x)
assert y.shape == x[0].shape
assert torch.allclose(y, x[0]) # .. since actually all ones.
def _test_random_combine_main() -> None:
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.0)
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.3)
_test_random_combine(0.5, 1, 0.3)
_test_random_combine(0.5, 0.5, 0.3)
if __name__ == "__main__":
_test_random_combine_main()
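A standalone sketch of the mixed-weight branch described in the docstring (it mirrors `_get_random_mixed_weights` above): noisy log-weights biased towards the final layer are normalized with a softmax, so every frame still gets weights that sum to one.

```
import torch

num_frames, num_inputs = 5, 3
final_weight, stddev = 0.5, 2.0

# offset that assigns `final_weight` to the last input before noise is added
final_log_weight = torch.tensor(final_weight / (1 - final_weight) * (num_inputs - 1)).log().item()

logprobs = torch.randn(num_frames, num_inputs) * stddev
logprobs[:, -1] += final_log_weight
weights = logprobs.softmax(dim=1)
print(weights.sum(dim=1))  # each row sums (up to rounding) to 1.0
```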

File diff suppressed because it is too large

View File

@ -0,0 +1,899 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
# Fangjun Kuang,
# Quandong Wang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import shutil
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import TedLiumAsrDataModule
from conformer import Conformer
from train import add_model_arguments
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
load_averaged_model,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--method",
type=str,
default="attention-decoder",
help="""Decoding method.
Supported values are:
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
- (1) ctc-greedy-search. It only uses the CTC output and a sentence piece
model for decoding. It produces the same results as ctc-decoding.
- (2) 1best. Extract the best path from the decoding lattice as the
decoding result.
- (3) nbest. Extract n paths from the decoding lattice; the path
with the highest score is the decoding result.
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
the highest score is the decoding result.
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
is the decoding result.
- (6) attention-decoder. Extract n paths from the LM rescored
lattice, the path with the highest score is the decoding result.
- (7) nbest-oracle. Its WER is the lower bound that any n-best
rescoring method can achieve. Useful for debugging n-best
rescoring methods.
""",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""Number of paths for n-best based decoding method.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values:
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
A smaller value results in more unique paths.
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc2/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--lm-path",
type=str,
default="data/lm/G_4_gram.pt",
help="""The n-gram LM dir for rescoring.
It should contain either lm_fname.pt or lm_fname.fst.txt
""",
)
parser.add_argument(
"--result-dir",
type=str,
default="conformer_ctc2/exp",
help="Directory to store results.",
)
add_model_arguments(parser)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
"""
params = AttributeDict(
{
# parameters for conformer
"subsampling_factor": 4,
"feature_dim": 80,
# parameters for decoding
"search_beam": 15,
"output_beam": 8,
"min_active_states": 10,
"max_active_states": 7000,
"use_double_scores": True,
"env_info": get_env_info(),
}
)
return params
def ctc_greedy_search(
ctc_probs: torch.Tensor,
mask: torch.Tensor,
) -> List[List[int]]:
"""Apply CTC greedy search
Args:
ctc_probs (torch.Tensor): (batch, max_len, num_bpe)
mask (torch.Tensor): (batch, max_len)
Returns:
best path result
"""
_, max_index = ctc_probs.max(2) # (B, maxlen)
max_index = max_index.masked_fill_(mask, 0) # (B, maxlen)
ret_hyps = []
for hyp in max_index:
hyp = torch.unique_consecutive(hyp)
hyp = hyp[hyp > 0].tolist()
ret_hyps.append(hyp)
return ret_hyps
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
model:
The neural model.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return the decoding result. See above description for the format of
the returned dict. Note: If it decodes to nothing, then return None.
"""
if HLG is not None:
device = HLG.device
else:
device = H.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
# nnet_output is (N, T, C)
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
torch.div(
supervisions["start_frame"],
params.subsampling_factor,
rounding_mode="floor",
),
torch.div(
supervisions["num_frames"],
params.subsampling_factor,
rounding_mode="floor",
),
),
1,
).to(torch.int32)
if H is None:
assert HLG is not None
decoding_graph = HLG
else:
assert HLG is None
assert bpe_model is not None
decoding_graph = H
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=decoding_graph,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "ctc-decoding":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
# token_ids is a list-of-list of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
unk = bpe_model.decode(bpe_model.unk_id()).strip()
hyps = [[w for w in s.split() if w != unk] for s in hyps]
key = "ctc-decoding"
return {key: hyps}
if params.method == "ctc-greedy-search":
hyps = ctc_greedy_search(nnet_output, memory_key_padding_mask)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(hyps)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
unk = bpe_model.decode(bpe_model.unk_id()).strip()
hyps = [[w for w in s.split() if w != unk] for s in hyps]
key = "ctc-greedy-search"
return {key: hyps}
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is only slightly worse than that of rescored lattices.
best_path = nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
word_table=word_table,
nbest_scale=params.nbest_scale,
oov="<unk>",
)
hyps = get_texts(best_path)
hyps = [
[word_table[i] for i in ids if word_table[i] != "<unk>"] for ids in hyps
]
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
return {key: hyps}
if params.method == "nbest":
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
nbest_scale=params.nbest_scale,
)
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [
[word_table[i] for i in ids if word_table[i] != "<unk>"] for ids in hyps
]
return {key: hyps}
assert params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
"attention-decoder",
]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "1best":
best_path_dict = one_best_decoding(
lattice=lattice,
lm_scale_list=lm_scale_list,
)
elif params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
nbest_scale=params.nbest_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=lm_scale_list,
)
elif params.method == "attention-decoder":
best_path_dict = rescore_with_attention_decoder(
lattice=lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
nbest_scale=params.nbest_scale,
)
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
ans = dict()
if best_path_dict is not None:
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [
[word_table[i] for i in ids if word_table[i] != "<unk>"] for ids in hyps
]
ans[lm_scale_str] = hyps
else:
ans = None
return ans
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
HLG: Optional[k2.Fsa],
H: Optional[k2.Fsa],
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
HLG:
The decoding graph. Used only when params.method is NOT ctc-decoding.
H:
The ctc topo. Used only when params.method is ctc-decoding.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table:
It is the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
if hyps_dict is not None:
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch)
else:
assert len(results) > 0, "It should not decode to empty in the first batch!"
this_batch = []
hyp_words = []
for cut_id, ref_text in zip(cut_ids, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
for lm_scale in results.keys():
results[lm_scale].extend(this_batch)
num_cuts += len(texts)
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
) -> None:
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.result_dir / f"recogs-{test_set_name}-{key}.txt"
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.result_dir / f"errs-{test_set_name}-{key}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.result_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main() -> None:
parser = get_parser()
TedLiumAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
args.lm_path = Path(args.lm_path)
args.result_dir = Path(args.result_dir)
if args.result_dir.is_dir():
shutil.rmtree(args.result_dir)
args.result_dir.mkdir()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
if params.method in ("ctc-decoding", "ctc-greedy-search"):
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
else:
H = None
bpe_model = None
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in ("nbest-rescoring", "whole-lattice-rescoring"):
assert params.lm_path.suffix in (".pt", ".txt")
if params.lm_path.is_file() and params.lm_path.suffix == ".pt":
logging.info(f"Loading pre-compiled {params.lm_path.name}")
d = torch.load(params.lm_path, map_location=device)
G = k2.Fsa.from_dict(d)
elif not params.lm_path.is_file() and params.lm_path.suffix == ".txt":
raise FileNotFoundError(f"No such language model file: '{params.lm_path}'")
else:
# We reach this branch only if the LM filename ends with '.pt' and the
# file does not exist, or if it ends with '.txt' and the file exists.
if (
not params.lm_path.is_file()
and params.lm_path.suffix == ".pt"
and not (
params.lm_path.parent / f"{params.lm_path.stem}.fst.txt"
).is_file()
):
raise FileNotFoundError(
f"No such language model file: '{params.lm_path}'\n"
"'.fst.txt' representation of the language model was "
"not found either."
)
else:
# whether params.lm_path.name is lm_name.pt or lm_name.fst.txt,
# we are going to load lm_name.fst.txt here
params.lm_path = params.lm_path.parent / params.lm_path.name.replace(
".pt", ".fst.txt"
)
logging.info(f"Loading {params.lm_path.name}")
logging.warning("It may take 8 minutes.")
with open(params.lm_path) as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
# See https://github.com/k2-fsa/k2/issues/874
# for why we need to set G.properties to None
G.__dict__["_properties"] = None
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
# Save a dummy value so that it can be loaded in C++.
# See https://github.com/pytorch/pytorch/issues/67902
# for why we need to do this.
G.dummy = 1
torch.save(
G.as_dict(),
params.lm_path.parent
/ params.lm_path.name.replace(".fst.txt", ".pt"),
)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
model = Conformer(
num_features=params.feature_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
d_model=params.dim_model,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
num_decoder_layers=params.num_decoder_layers,
)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
tedlium = TedLiumAsrDataModule(args)
valid_cuts = tedlium.dev_cuts()
test_cuts = tedlium.test_cuts()
valid_dl = tedlium.valid_dataloaders(valid_cuts)
test_dl = tedlium.test_dataloaders(test_cuts)
test_sets = ["dev", "test"]
test_dls = [valid_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
H=H,
bpe_model=bpe_model,
word_table=lexicon.word_table,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)
save_results(params=params, test_set_name=test_set, results_dict=results_dict)
logging.info("Done!")
torch.set_num_threads(1)
# Importing add_model_arguments from train.py already calls
# torch.set_num_interop_threads(1). Calling it a second time here
# would raise an error, so we only call it if it has not been set yet.
if torch.get_num_interop_threads() != 1:
torch.set_num_interop_threads(1)
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
if __name__ == "__main__":
main()
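A toy demonstration of the `ctc_greedy_search` logic defined above: frame-wise argmax ids are collapsed with `unique_consecutive` and the blank id (assumed to be 0 here, as implied by `hyp[hyp > 0]`) is dropped.

```
import torch

max_index = torch.tensor([0, 5, 5, 0, 0, 7, 7, 7, 0, 3])  # made-up argmax ids for one utterance
hyp = torch.unique_consecutive(max_index)
hyp = hyp[hyp > 0].tolist()
print(hyp)  # [5, 7, 3]
```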

View File

@ -0,0 +1,294 @@
#!/usr/bin/env python3
#
# Copyright 2022 Behavox LLC (Author: Daniil Kulko)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
./conformer_ctc2/export.py \
--exp-dir ./conformer_ctc2/exp \
--epoch 20 \
--avg 10
It will generate a file exp_dir/pretrained.pt
To use the generated file with `conformer_ctc2/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/tedlium3/ASR
./conformer_ctc2/decode.py \
--exp-dir ./conformer_ctc2/exp \
--epoch 9999 \
--avg 1 \
--max-duration 100
"""
import argparse
import logging
from pathlib import Path
import torch
from conformer import Conformer
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, str2bool
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for averaging.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help=(
"Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'"
),
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help=(
"Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. "
),
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc2/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="The lang dir",
)
parser.add_argument(
"--jit",
type=str2bool,
default=True,
help="""True to save a model after applying torch.jit.script.
""",
)
add_model_arguments(parser)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
"""
# parameters for conformer
params = AttributeDict({"subsampling_factor": 4, "feature_dim": 80})
return params
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info(params)
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
d_model=params.dim_model,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
num_decoder_layers=params.num_decoder_layers,
)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit:
convert_scaled_to_non_scaled(model, inplace=True)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torch.jit.script")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
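Once `export.py --jit true` has produced `cpu_jit.pt`, the scripted model can be loaded without any of the training code; a minimal sketch with hypothetical paths, assuming the export has already been run:

```
import torch

# scripted model saved by export.py with --jit true
model = torch.jit.load("conformer_ctc2/exp/cpu_jit.pt")
model.eval()

# the non-jit export writes a plain checkpoint instead:
state = torch.load("conformer_ctc2/exp/pretrained.pt", map_location="cpu")
print(state.keys())  # dict_keys(['model'])
```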

View File

@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc2/subsampling.py

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -4,16 +4,18 @@
"""
Convert a transcript based on words to a list of BPE ids.
For example, if we use 2 as the encoding id of <unk>:
For example, if we use 2 as the encoding id of <unk>
Note: it inserts a space token before each <unk>
texts = ['this is a <unk> day']
spm_ids = [[38, 33, 6, 2, 316]]
spm_ids = [[38, 33, 6, 15, 2, 316]]
texts = ['<unk> this is a sunny day']
spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316]]
spm_ids = [[15, 2, 38, 33, 6, 118, 11, 11, 21, 316]]
texts = ['<unk>']
spm_ids = [[2]]
spm_ids = [[15, 2]]
"""
import argparse
@ -38,29 +40,27 @@ def get_args():
def convert_texts_into_ids(
texts: List[str],
unk_id: int,
sp: spm.SentencePieceProcessor,
) -> List[List[int]]:
"""
Args:
texts:
A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
unk_id:
A number id for the token '<unk>'.
sp:
A sentencepiece BPE model.
Returns:
Return an integer list of bpe ids.
"""
y = []
for text in texts:
y_ids = []
if "<unk>" in text:
text_segments = text.split("<unk>")
id_segments = sp.encode(text_segments, out_type=int)
id_segments = sp.encode(text.split("<unk>"), out_type=int)
y_ids = []
for i in range(len(id_segments)):
if i != len(id_segments) - 1:
y_ids.extend(id_segments[i] + [unk_id])
else:
y_ids.extend(id_segments[i])
y_ids += id_segments[i]
if i < len(id_segments) - 1:
y_ids += [sp.piece_to_id(""), sp.unk_id()]
else:
y_ids = sp.encode(text, out_type=int)
y.append(y_ids)
@ -70,19 +70,13 @@ def convert_texts_into_ids(
def main():
args = get_args()
texts = args.texts
bpe_model = args.bpe_model
sp = spm.SentencePieceProcessor()
sp.load(bpe_model)
unk_id = sp.piece_to_id("<unk>")
sp.load(args.bpe_model)
y = convert_texts_into_ids(
texts=texts,
unk_id=unk_id,
sp=sp,
)
logging.info(f"The input texts: {texts}")
y = convert_texts_into_ids(texts=args.texts, sp=sp)
logging.info(f"The input texts: {args.texts}")
logging.info(f"The encoding ids: {y}")

View File

@ -1 +0,0 @@
../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py

View File

@ -1 +0,0 @@
../../../librispeech/ASR/local/generate_unique_lexicon.py

View File

@ -1 +0,0 @@
../../../librispeech/ASR/local/prepare_lang.py

View File

@ -1,94 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input supervisions json dir "data/manifests"
consisting of supervisions_train.json and does the following:
1. Generate lexicon_words.txt.
"""
import argparse
import logging
from pathlib import Path
import lhotse
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifests-dir",
type=str,
help="""Input directory.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Output directory.
""",
)
return parser.parse_args()
def prepare_lexicon(manifests_dir: str, lang_dir: str):
"""
Args:
manifests_dir:
The manifests directory, e.g., data/manifests.
lang_dir:
The language directory, e.g., data/lang_phone.
Return:
The lexicon_words.txt file.
"""
words = set()
lexicon = Path(lang_dir) / "lexicon_words.txt"
sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
for s in sups:
# list the words units and filter the empty item
words_list = list(filter(None, s.text.split()))
for word in words_list:
if word not in words and word != "<unk>":
words.add(word)
with open(lexicon, "w") as f:
for word in sorted(words):
f.write(word + " " + word)
f.write("\n")
def main():
args = get_args()
manifests_dir = Path(args.manifests_dir)
lang_dir = Path(args.lang_dir)
logging.info("Generating lexicon_words.txt")
prepare_lexicon(manifests_dir, lang_dir)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
# Copyright 2021 Xiaomi Corp. (author: Mingshuang Luo)
# Copyright 2022 Behavox LLC. (author: Daniil Kulko)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -17,68 +18,67 @@
"""
This script takes as input supervisions json dir "data/manifests"
consisting of supervisions_train.json and does the following:
1. Generate train.text.
This script takes an input text file and removes all words
that include any character outside the English alphabet.
"""
import argparse
import logging
import re
from pathlib import Path
import lhotse
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifests-dir",
"--input-text-path",
type=str,
help="""Input directory.
""",
help="Input text file path.",
)
parser.add_argument(
"--lang-dir",
"--output-text-path",
type=str,
help="""Output directory.
""",
help="Output text file path.",
)
return parser.parse_args()
def prepare_transcripts(manifests_dir: str, lang_dir: str):
def prepare_transcripts(input_text_path: Path, output_text_path: Path) -> None:
"""
Args:
manifests_dir:
The manifests directory, e.g., data/manifests.
lang_dir:
The language directory, e.g., data/lang_phone.
input_text_path:
The input data text file path, e.g., data/lang/train_orig.txt.
output_text_path:
The output data text file path, e.g., data/lang/train.txt.
Return:
The train.text in lang_dir.
Saved text file in output_text_path.
"""
texts = []
train_text = Path(lang_dir) / "train.text"
sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
for s in sups:
texts.append(s.text)
foreign_chr_check = re.compile(r"[^a-z']")
with open(train_text, "w") as f:
for text in texts:
f.write(text)
f.write("\n")
logging.info(f"Loading {input_text_path.name}")
with open(input_text_path, "r", encoding="utf8") as f:
texts = {t.rstrip("\n") for t in f}
texts = {
" ".join([w for w in t.split() if foreign_chr_check.search(w) is None])
for t in texts
}
with open(output_text_path, "w+", encoding="utf8") as f:
for t in texts:
f.write(f"{t}\n")
def main():
def main() -> None:
args = get_args()
manifests_dir = Path(args.manifests_dir)
lang_dir = Path(args.lang_dir)
input_text_path = Path(args.input_text_path)
output_text_path = Path(args.output_text_path)
logging.info("Generating train.text")
prepare_transcripts(manifests_dir, lang_dir)
logging.info(f"Generating {output_text_path.name}")
prepare_transcripts(input_text_path, output_text_path)
if __name__ == "__main__":
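
A quick illustration (not part of the patch) of what the foreign_chr_check filter keeps and drops:

import re

foreign_chr_check = re.compile(r"[^a-z']")

line = "hello won't café 123 naive"
kept = " ".join(w for w in line.split() if foreign_chr_check.search(w) is None)
# Only words made of lowercase ASCII letters and apostrophes survive:
print(kept)  # hello won't naive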

View File

@ -0,0 +1,83 @@
#!/usr/bin/env python3
# Copyright 2022 Behavox LLC. (authors: Daniil Kulko)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script takes as input the lang dir "data/lang"
containing words_orig.txt and does the following:
1. Generate words.txt.
"""
import argparse
import logging
import re
from pathlib import Path
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lang-dir",
type=str,
help="Output directory.",
)
return parser.parse_args()
def prepare_words(lang_dir: str) -> None:
"""
Args:
lang_dir:
The language directory, e.g., data/lang.
Return:
The words.txt file.
"""
words_orig_path = Path(lang_dir) / "words_orig.txt"
words_path = Path(lang_dir) / "words.txt"
foreign_chr_check = re.compile(r"[^a-z']")
logging.info(f"Loading {words_orig_path.name}")
with open(words_orig_path, "r", encoding="utf8") as f:
words = {w for w_compl in f for w in w_compl.strip("-\n").split("_")}
words = {w for w in words if foreign_chr_check.search(w) is None and w != ""}
words.add("<unk>")
words = ["<eps>", "!SIL"] + sorted(words) + ["#0", "<s>", "</s>"]
with open(words_path, "w+", encoding="utf8") as f:
for idx, word in enumerate(words):
f.write(f"{word} {idx}\n")
def main() -> None:
args = get_args()
lang_dir = Path(args.lang_dir)
logging.info("Generating words.txt")
prepare_words(lang_dir)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
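
For illustration (hypothetical three-word vocabulary, not part of the patch), the id layout this produces looks like:

words = ["<eps>", "!SIL"] + sorted({"a", "hello", "<unk>"}) + ["#0", "<s>", "</s>"]
for idx, word in enumerate(words):
    print(f"{word} {idx}")
# <eps> 0, !SIL 1, <unk> 2, a 3, hello 4, #0 5, <s> 6, </s> 7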

View File

@ -1 +0,0 @@
../../../librispeech/ASR/local/test_prepare_lang.py

View File

@ -5,7 +5,6 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
stage=0
stop_stage=100
@ -63,6 +62,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
mv $dl_dir/TEDLIUM_release-3 $dl_dir/tedlium3
fi
# Download big and small 4-gram language models
if [ ! -d $dl_dir/lm ]; then
wget --continue http://kaldi-asr.org/models/5/4gram_small.arpa.gz -P $dl_dir/lm
wget --continue http://kaldi-asr.org/models/5/4gram_big.arpa.gz -P $dl_dir/lm
gzip -d $dl_dir/lm/4gram_small.arpa.gz $dl_dir/lm/4gram_big.arpa.gz
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
@ -100,7 +106,14 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
if [ ! -e data/fbank/.tedlium3.done ]; then
mkdir -p data/fbank
python3 ./local/compute_fbank_tedlium.py
gunzip -c data/fbank/tedlium_cuts_train.jsonl.gz | shuf | \
gzip -c > data/fbank/tedlium_cuts_train-shuf.jsonl.gz
mv data/fbank/tedlium_cuts_train-shuf.jsonl.gz \
data/fbank/tedlium_cuts_train.jsonl.gz
touch data/fbank/.tedlium3.done
fi
fi
@ -115,28 +128,24 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Prepare phone based lang"
lang_dir=data/lang_phone
log "Stage 5: Prepare BPE train data and set of words"
lang_dir=data/lang
mkdir -p $lang_dir
if [ ! -f $lang_dir/train.text ]; then
if [ ! -f $lang_dir/train.txt ]; then
gunzip -c $dl_dir/tedlium3/LM/*.en.gz | sed 's: <\/s>::g' > $lang_dir/train_orig.txt
./local/prepare_transcripts.py \
--lang-dir $lang_dir \
--manifests-dir data/manifests
--input-text-path $lang_dir/train_orig.txt \
--output-text-path $lang_dir/train.txt
fi
if [ ! -f $lang_dir/lexicon_words.txt ]; then
./local/prepare_lexicon.py \
--lang-dir $lang_dir \
--manifests-dir data/manifests
fi
if [ ! -f $lang_dir/words.txt ]; then
(echo '!SIL SIL'; echo '<UNK> <UNK>'; ) |
cat - $lang_dir/lexicon_words.txt |
sort | uniq > $lang_dir/lexicon.txt
awk '{print $1}' $dl_dir/tedlium3/TEDLIUM.152k.dic |
sed 's:([0-9])::g' | sort | uniq > $lang_dir/words_orig.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang.py --lang-dir $lang_dir
./local/prepare_words.py --lang-dir $lang_dir
fi
fi
@ -148,25 +157,56 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
mkdir -p $lang_dir
# We reuse words.txt from data/lang
# so that all lang dirs can share G.pt later.
cp data/lang_phone/words.txt $lang_dir
if [ ! -f $lang_dir/transcript_words.txt ]; then
log "Generate data for BPE training"
cat data/lang_phone/train.text |
cut -d " " -f 2- > $lang_dir/transcript_words.txt
# remove the <unk> for transcript_words.txt
sed -i 's/ <unk>//g' $lang_dir/transcript_words.txt
sed -i 's/<unk> //g' $lang_dir/transcript_words.txt
sed -i 's/<unk>//g' $lang_dir/transcript_words.txt
fi
cp data/lang/words.txt $lang_dir
./local/train_bpe_model.py \
--lang-dir $lang_dir \
--vocab-size $vocab_size \
--transcript $lang_dir/transcript_words.txt
--transcript data/lang/train.txt
if [ ! -f $lang_dir/L_disambig.pt ]; then
./local/prepare_lang_bpe.py --lang-dir $lang_dir
./local/prepare_lang_bpe.py --lang-dir $lang_dir --oov "<unk>"
fi
done
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Prepare G"
# We assume you have installed kaldilm; if not, please install
# it using: pip install kaldilm
mkdir -p data/lm
if [ ! -f data/lm/G_4_gram_small.fst.txt ]; then
# It is used in building HLG
python3 -m kaldilm \
--read-symbol-table="data/lang/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
--max-arpa-warnings=-1 \
$dl_dir/lm/4gram_small.arpa > data/lm/G_4_gram_small.fst.txt
fi
if [ ! -f data/lm/G_4_gram_big.fst.txt ]; then
# It is used for LM rescoring
python3 -m kaldilm \
--read-symbol-table="data/lang/words.txt" \
--disambig-symbol='#0' \
--max-order=4 \
--max-arpa-warnings=-1 \
$dl_dir/lm/4gram_big.arpa > data/lm/G_4_gram_big.fst.txt
fi
fi
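
Downstream, these .fst.txt files are typically read back with k2 roughly as follows (a sketch, not part of this patch):

import k2

with open("data/lm/G_4_gram_small.fst.txt") as f:
    # from_openfst parses the kaldilm text format; acceptor=False keeps the
    # separate input/output labels on each arc.
    G = k2.Fsa.from_openfst(f.read(), acceptor=False)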
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Compile HLG"
for vocab_size in ${vocab_sizes[@]}; do
lang_dir=data/lang_bpe_${vocab_size}
if [ ! -f $lang_dir/HLG.pt ]; then
./local/compile_hlg.py \
--lang-dir $lang_dir \
--lm G_4_gram_small
fi
done
fi
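
And the compiled HLG.pt is usually loaded back for decoding roughly like this (illustrative sketch):

import k2
import torch

HLG = k2.Fsa.from_dict(torch.load("data/lang_bpe_500/HLG.pt", map_location="cpu"))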

View File

@ -466,9 +466,7 @@ def one_best_decoding(
Return:
An FsaVec containing linear paths.
"""
if lm_scale_list is not None:
ans = dict()
saved_am_scores = lattice.scores - lattice.lm_scores
for lm_scale in lm_scale_list:

View File

@ -112,7 +112,7 @@ def uniq_lexicon_test():
# But there is no word "ca" in the lexicon, so our
# implementation returns the id of "<UNK>"
print(token_ids, expected_token_ids)
assert token_ids.tolist() == [[sp.unk_id()]]
assert token_ids.tolist() == [[sp.piece_to_id("▁"), sp.unk_id()]]
# case 3: With OOV
texts = ["foo"]