mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 17:42:21 +00:00
Tedlium3 conformer ctc2 (#696)
* modify preparation * small refactor * add tedlium3 conformer_ctc2 * modify decode * filter unk in decode * add scaling converter * address comments * fix lambda function lhotse * add implicit manifest shuffle * refactor ctc_greedy_search * import model arguments from train.py * style fix * fix ci test and last style issues * update RESULTS * fix RESULTS numbers * fix label smoothing loss * update model parameters number in RESULTS
This commit is contained in:
parent
0470bbae66
commit
b293db4baf
@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module):
|
|||||||
mean of the output is taken. (3) "sum": the output will be summed.
|
mean of the output is taken. (3) "sum": the output will be summed.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert 0.0 <= label_smoothing < 1.0
|
assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
|
||||||
|
assert reduction in ("none", "sum", "mean"), reduction
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
self.label_smoothing = label_smoothing
|
self.label_smoothing = label_smoothing
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
@ -24,10 +24,9 @@ from scaling import (
|
|||||||
ScaledConv2d,
|
ScaledConv2d,
|
||||||
ScaledLinear,
|
ScaledLinear,
|
||||||
)
|
)
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
class Conv2dSubsampling(nn.Module):
|
class Conv2dSubsampling(torch.nn.Module):
|
||||||
"""Convolutional 2D subsampling (to 1/4 length).
|
"""Convolutional 2D subsampling (to 1/4 length).
|
||||||
|
|
||||||
Convert an input of shape (N, T, idim) to an output
|
Convert an input of shape (N, T, idim) to an output
|
||||||
@ -61,7 +60,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
assert in_channels >= 7
|
assert in_channels >= 7
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.conv = nn.Sequential(
|
self.conv = torch.nn.Sequential(
|
||||||
ScaledConv2d(
|
ScaledConv2d(
|
||||||
in_channels=1,
|
in_channels=1,
|
||||||
out_channels=layer1_channels,
|
out_channels=layer1_channels,
|
||||||
|
@ -1435,7 +1435,7 @@ class EmformerEncoder(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
states: List[torch.Tensor],
|
states: List[torch.Tensor],
|
||||||
) -> Tuple[torch.Tensor, List[torch.Tensor],]:
|
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||||
"""Forward pass for streaming inference.
|
"""Forward pass for streaming inference.
|
||||||
|
|
||||||
B: batch size;
|
B: batch size;
|
||||||
@ -1640,7 +1640,7 @@ class Emformer(EncoderInterface):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
states: List[torch.Tensor],
|
states: List[torch.Tensor],
|
||||||
) -> Tuple[torch.Tensor, List[torch.Tensor],]:
|
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||||
"""Forward pass for streaming inference.
|
"""Forward pass for streaming inference.
|
||||||
|
|
||||||
B: batch size;
|
B: batch size;
|
||||||
|
@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from
|
|||||||
|
|
||||||
Caution: We use a lexicon that contains disambiguation symbols
|
Caution: We use a lexicon that contains disambiguation symbols
|
||||||
|
|
||||||
- G, the LM, built from data/lm/G_3_gram.fst.txt
|
- G, the LM, built from data/lm/G_n_gram.fst.txt
|
||||||
|
|
||||||
The generated HLG is saved in $lang_dir/HLG.pt
|
The generated HLG is saved in $lang_dir/HLG.pt
|
||||||
"""
|
"""
|
||||||
|
@ -28,7 +28,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine
|
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import get_executor
|
from icefall.utils import get_executor
|
||||||
@ -41,6 +41,10 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
|
||||||
|
def is_cut_long(c: MonoCut) -> bool:
    """Tell whether a cut is long enough to keep.

    Args:
      c:
        The cut to inspect; only its ``duration`` attribute is read.
    Returns:
      True when ``c.duration`` is strictly greater than 5, False otherwise.
    """
    # NOTE(review): presumably a named function is used instead of a lambda
    # so the predicate can be pickled by lhotse -- confirm.
    min_duration = 5  # same unit as ``c.duration`` (presumably seconds -- confirm)
    return c.duration > min_duration
|
||||||
|
|
||||||
|
|
||||||
def compute_fbank_musan():
|
def compute_fbank_musan():
|
||||||
src_dir = Path("data/manifests")
|
src_dir = Path("data/manifests")
|
||||||
output_dir = Path("data/fbank")
|
output_dir = Path("data/fbank")
|
||||||
@ -86,7 +90,7 @@ def compute_fbank_musan():
|
|||||||
recordings=combine(part["recordings"] for part in manifests.values())
|
recordings=combine(part["recordings"] for part in manifests.values())
|
||||||
)
|
)
|
||||||
.cut_into_windows(10.0)
|
.cut_into_windows(10.0)
|
||||||
.filter(lambda c: c.duration > 5)
|
.filter(is_cut_long)
|
||||||
.compute_and_store_features(
|
.compute_and_store_features(
|
||||||
extractor=extractor,
|
extractor=extractor,
|
||||||
storage_path=f"{output_dir}/musan_feats",
|
storage_path=f"{output_dir}/musan_feats",
|
||||||
|
@ -127,7 +127,7 @@ def lexicon_to_fst_no_sil(
|
|||||||
|
|
||||||
|
|
||||||
def generate_lexicon(
|
def generate_lexicon(
|
||||||
model_file: str, words: List[str]
|
model_file: str, words: List[str], oov: str
|
||||||
) -> Tuple[Lexicon, Dict[str, int]]:
|
) -> Tuple[Lexicon, Dict[str, int]]:
|
||||||
"""Generate a lexicon from a BPE model.
|
"""Generate a lexicon from a BPE model.
|
||||||
|
|
||||||
@ -136,6 +136,8 @@ def generate_lexicon(
|
|||||||
Path to a sentencepiece model.
|
Path to a sentencepiece model.
|
||||||
words:
|
words:
|
||||||
A list of strings representing words.
|
A list of strings representing words.
|
||||||
|
oov:
|
||||||
|
The out of vocabulary word in lexicon.
|
||||||
Returns:
|
Returns:
|
||||||
Return a tuple with two elements:
|
Return a tuple with two elements:
|
||||||
- A dict whose keys are words and values are the corresponding
|
- A dict whose keys are words and values are the corresponding
|
||||||
@ -156,12 +158,9 @@ def generate_lexicon(
|
|||||||
for word, pieces in zip(words, words_pieces):
|
for word, pieces in zip(words, words_pieces):
|
||||||
lexicon.append((word, pieces))
|
lexicon.append((word, pieces))
|
||||||
|
|
||||||
# The OOV word is <UNK>
|
lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())]))
|
||||||
lexicon.append(("<UNK>", [sp.id_to_piece(sp.unk_id())]))
|
|
||||||
|
|
||||||
token2id: Dict[str, int] = dict()
|
token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())}
|
||||||
for i in range(sp.vocab_size()):
|
|
||||||
token2id[sp.id_to_piece(i)] = i
|
|
||||||
|
|
||||||
return lexicon, token2id
|
return lexicon, token2id
|
||||||
|
|
||||||
@ -176,6 +175,13 @@ def get_args():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--oov",
|
||||||
|
type=str,
|
||||||
|
default="<UNK>",
|
||||||
|
help="The out of vocabulary word in lexicon.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--debug",
|
"--debug",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -202,12 +208,13 @@ def main():
|
|||||||
|
|
||||||
words = word_sym_table.symbols
|
words = word_sym_table.symbols
|
||||||
|
|
||||||
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]
|
excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"]
|
||||||
|
|
||||||
for w in excluded:
|
for w in excluded:
|
||||||
if w in words:
|
if w in words:
|
||||||
words.remove(w)
|
words.remove(w)
|
||||||
|
|
||||||
lexicon, token_sym_table = generate_lexicon(model_file, words)
|
lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov)
|
||||||
|
|
||||||
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
|
||||||
|
|
||||||
|
@ -652,16 +652,16 @@ class ActivationBalancer(torch.nn.Module):
|
|||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
if random.random() >= self.balance_prob:
|
if random.random() >= self.balance_prob:
|
||||||
return x
|
return x
|
||||||
else:
|
|
||||||
return ActivationBalancerFunction.apply(
|
return ActivationBalancerFunction.apply(
|
||||||
x,
|
x,
|
||||||
self.channel_dim,
|
self.channel_dim,
|
||||||
self.min_positive,
|
self.min_positive,
|
||||||
self.max_positive,
|
self.max_positive,
|
||||||
self.max_factor / self.balance_prob,
|
self.max_factor / self.balance_prob,
|
||||||
self.min_abs,
|
self.min_abs,
|
||||||
self.max_abs,
|
self.max_abs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DoubleSwishFunction(torch.autograd.Function):
|
class DoubleSwishFunction(torch.autograd.Function):
|
||||||
|
@ -282,7 +282,7 @@ def convert_scaled_to_non_scaled(
|
|||||||
if not inplace:
|
if not inplace:
|
||||||
model = copy.deepcopy(model)
|
model = copy.deepcopy(model)
|
||||||
|
|
||||||
excluded_patterns = r"self_attn\.(in|out)_proj"
|
excluded_patterns = r"(self|src)_attn\.(in|out)_proj"
|
||||||
p = re.compile(excluded_patterns)
|
p = re.compile(excluded_patterns)
|
||||||
|
|
||||||
d = {}
|
d = {}
|
||||||
|
@ -1,5 +1,88 @@
|
|||||||
## Results
|
## Results
|
||||||
|
|
||||||
|
### TedLium3 BPE training results (Conformer-CTC 2)
|
||||||
|
|
||||||
|
#### [conformer_ctc2](./conformer_ctc2)
|
||||||
|
|
||||||
|
See <https://github.com/k2-fsa/icefall/pull/696> for more details.
|
||||||
|
|
||||||
|
The tensorboard log can be found at
|
||||||
|
<https://tensorboard.dev/experiment/5NQQiqOqSqazfn4w2yeWEQ/>
|
||||||
|
|
||||||
|
You can find a pretrained model and decoding results at:
|
||||||
|
<https://huggingface.co/videodanchik/icefall-asr-tedlium3-conformer-ctc2>
|
||||||
|
|
||||||
|
Number of model parameters: 101141699, i.e., 101.14 M
|
||||||
|
|
||||||
|
The WERs are
|
||||||
|
|
||||||
|
| | dev | test | comment |
|
||||||
|
|--------------------------|------------|-------------|---------------------|
|
||||||
|
| ctc decoding | 6.45 | 5.96 | --epoch 38 --avg 26 |
|
||||||
|
| 1best | 5.92 | 5.51 | --epoch 38 --avg 26 |
|
||||||
|
| whole lattice rescoring | 5.96 | 5.47 | --epoch 38 --avg 26 |
|
||||||
|
| attention decoder | 5.60 | 5.33 | --epoch 38 --avg 26 |
|
||||||
|
|
||||||
|
The training command for reproducing is given below:
|
||||||
|
|
||||||
|
```
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
|
./conformer_ctc2/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 40 \
|
||||||
|
--exp-dir conformer_ctc2/exp \
|
||||||
|
--max-duration 350 \
|
||||||
|
--use-fp16 true
|
||||||
|
```
|
||||||
|
|
||||||
|
The decoding command is:
|
||||||
|
```
|
||||||
|
epoch=38
|
||||||
|
avg=26
|
||||||
|
|
||||||
|
## ctc decoding
|
||||||
|
./conformer_ctc2/decode.py \
|
||||||
|
--method ctc-decoding \
|
||||||
|
--exp-dir conformer_ctc2/exp \
|
||||||
|
--lang-dir data/lang_bpe_500 \
|
||||||
|
--result-dir conformer_ctc2/exp \
|
||||||
|
--max-duration 500 \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg
|
||||||
|
|
||||||
|
## 1best
|
||||||
|
./conformer_ctc2/decode.py \
|
||||||
|
--method 1best \
|
||||||
|
--exp-dir conformer_ctc2/exp \
|
||||||
|
--lang-dir data/lang_bpe_500 \
|
||||||
|
--result-dir conformer_ctc2/exp \
|
||||||
|
--max-duration 500 \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg
|
||||||
|
|
||||||
|
## whole lattice rescoring
|
||||||
|
./conformer_ctc2/decode.py \
|
||||||
|
--method whole-lattice-rescoring \
|
||||||
|
--exp-dir conformer_ctc2/exp \
|
||||||
|
--lm-path data/lm/G_4_gram_big.pt \
|
||||||
|
--lang-dir data/lang_bpe_500 \
|
||||||
|
--result-dir conformer_ctc2/exp \
|
||||||
|
--max-duration 500 \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg
|
||||||
|
|
||||||
|
## attention decoder
|
||||||
|
./conformer_ctc2/decode.py \
|
||||||
|
--method attention-decoder \
|
||||||
|
--exp-dir conformer_ctc2/exp \
|
||||||
|
--lang-dir data/lang_bpe_500 \
|
||||||
|
--result-dir conformer_ctc2/exp \
|
||||||
|
--max-duration 500 \
|
||||||
|
--epoch $epoch \
|
||||||
|
--avg $avg
|
||||||
|
```
|
||||||
|
|
||||||
### TedLium3 BPE training results (Pruned Transducer)
|
### TedLium3 BPE training results (Pruned Transducer)
|
||||||
|
|
||||||
#### 2022-03-21
|
#### 2022-03-21
|
||||||
|
0
egs/tedlium3/ASR/conformer_ctc2/__init__.py
Executable file
0
egs/tedlium3/ASR/conformer_ctc2/__init__.py
Executable file
1
egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
Symbolic link
1
egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../transducer_stateless/asr_datamodule.py
|
201
egs/tedlium3/ASR/conformer_ctc2/attention.py
Normal file
201
egs/tedlium3/ASR/conformer_ctc2/attention.py
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
# Copyright 2022 Behavox LLC. (author: Daniil Kulko)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(torch.nn.Module):
    """Allows the model to jointly attend to information
    from different representation subspaces. This is a modified
    version of the original version of multihead attention
    (see Attention Is All You Need <https://arxiv.org/abs/1706.03762>)
    with replacement of input / output projection layers
    with newly introduced ScaledLinear layer
    (see https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py).

    Args:
      embed_dim:
        total dimension of the model.
      num_heads:
        number of parallel attention heads. Note that embed_dim will be split
        across num_heads, i.e. each head will have dimension (embed_dim // num_heads).
      dropout:
        dropout probability on attn_output_weights. (default=0.0).
      bias:
        if specified, adds bias to input / output projection layers (default=True).
      add_bias_kv:
        if specified, adds bias to the key and value sequences at dim=0 (default=False).
      add_zero_attn:
        if specified, adds a new batch of zeros to the key and value sequences
        at dim=1 (default=False).
      batch_first:
        if True, then the input and output tensors are provided as
        (batch, seq, feature), otherwise (seq, batch, feature) (default=False).
      device:
        forwarded to the projection layers and used for the optional
        bias_k / bias_v parameters (default=None).
      dtype:
        forwarded to the projection layers and used for the optional
        bias_k / bias_v parameters (default=None).

    Raises:
      ValueError:
        if embed_dim is not divisible by num_heads.

    Examples::
        >>> multihead_attn = MultiheadAttention(embed_dim, num_heads)
        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
        add_bias_kv: bool = False,
        add_zero_attn: bool = False,
        batch_first: bool = False,
        device: Union[torch.device, str, None] = None,
        dtype: Union[torch.dtype, str, None] = None,
    ) -> None:

        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first

        if embed_dim % num_heads != 0:
            raise ValueError(
                "embed_dim must be divisible by num_heads. "
                "Got embedding dim vs number of heads: "
                f"{embed_dim} vs {num_heads}"
            )

        self.head_dim = embed_dim // num_heads

        # A single packed projection produces query, key and value at once,
        # hence the 3 * embed_dim output features.
        self.in_proj = ScaledLinear(
            embed_dim,
            3 * embed_dim,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.out_proj = ScaledLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            initial_scale=0.25,
            device=device,
            dtype=dtype,
        )

        if add_bias_kv:
            self.bias_k = torch.nn.Parameter(
                torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
            )
            self.bias_v = torch.nn.Parameter(
                torch.empty((1, 1, embed_dim), device=device, dtype=dtype)
            )
        else:
            # Register the names anyway so that `self.bias_k` / `self.bias_v`
            # always exist (as None) and can be passed on in forward().
            self.register_parameter("bias_k", None)
            self.register_parameter("bias_v", None)

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize bias_k / bias_v (when present) with Xavier normal init."""
        if self.bias_k is not None:
            torch.nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            torch.nn.init.xavier_normal_(self.bias_v)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
          query:
            Query embeddings of shape (L, N, E_q) when batch_first=False or (N, L, E_q)
            when batch_first=True, where L is the target sequence length, N is the batch size,
            and E_q is the query embedding dimension embed_dim. Queries are compared against
            key-value pairs to produce the output. See "Attention Is All You Need" for more details.
          key:
            Key embeddings of shape (S, N, E_k) when batch_first=False or (N, S, E_k) when
            batch_first=True, where S is the source sequence length, N is the batch size, and
            E_k is the key embedding dimension kdim. See "Attention Is All You Need" for more details.
          value:
            Value embeddings of shape (S, N, E_v) when batch_first=False or (N, S, E_v) when
            batch_first=True, where S is the source sequence length, N is the batch size, and
            E_v is the value embedding dimension vdim. See "Attention Is All You Need" for more details.
          key_padding_mask:
            If specified, a mask of shape (N, S) indicating which elements within key
            to ignore for the purpose of attention (i.e. treat as "padding").
            Binary and byte masks are supported. For a binary mask, a True value indicates
            that the corresponding key value will be ignored for the purpose of attention.
            For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.
          need_weights:
            If specified, returns attn_output_weights in addition to attn_outputs (default=True).
          attn_mask:
            If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
            (L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length,
            and S is the source sequence length. A 2D mask will be broadcasted across the batch while
            a 3D mask allows for a different mask for each entry in the batch.
            Binary, byte, and float masks are supported. For a binary mask, a True value indicates
            that the corresponding position is not allowed to attend. For a byte mask, a non-zero
            value indicates that the corresponding position is not allowed to attend. For a float mask,
            the mask values will be added to the attention weight.

        Returns:
          attn_output:
            Attention outputs of shape (L, N, E) when batch_first=False or (N, L, E) when batch_first=True,
            where L is the target sequence length, N is the batch size, and E is the embedding dimension
            embed_dim.
          attn_output_weights:
            Attention output weights of shape (N, L, S), where N is the batch size, L is the target sequence
            length, and S is the source sequence length. Only returned when need_weights=True.
        """
        # The functional implementation works on (seq, batch, feature), so
        # batch-first inputs are transposed on the way in ...
        if self.batch_first:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        (
            attn_output,
            attn_output_weights,
        ) = torch.nn.functional.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            in_proj_weight=self.in_proj.get_weight(),
            in_proj_bias=self.in_proj.get_bias(),
            bias_k=self.bias_k,
            bias_v=self.bias_v,
            add_zero_attn=self.add_zero_attn,
            dropout_p=self.dropout,
            out_proj_weight=self.out_proj.get_weight(),
            out_proj_bias=self.out_proj.get_bias(),
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
        )

        # ... and the output is transposed back on the way out.
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        return attn_output, attn_output_weights
|
244
egs/tedlium3/ASR/conformer_ctc2/combiner.py
Normal file
244
egs/tedlium3/ASR/conformer_ctc2/combiner.py
Normal file
@ -0,0 +1,244 @@
|
|||||||
|
# Copyright 2022 Behavox LLC. (author: Daniil Kulko)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class RandomCombine(torch.nn.Module):
    """
    Merge a list of equally-shaped Tensors into one Tensor of that shape.

    While training, the result is a random per-frame combination of all
    the inputs; outside of training it is simply the last input.

    The list of Tensors is meant to hold the outputs of several conformer
    layers, which gives an effect similar to iterated loss. (See:
    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
    NETWORKS).
    """

    def __init__(
        self,
        num_inputs: int,
        final_weight: float = 0.5,
        pure_prob: float = 0.5,
        stddev: float = 2.0,
    ) -> None:
        """
        Args:
          num_inputs:
            How many tensors are passed to forward(), i.e. how many layer
            outputs are combined. E.g. in an 18-layer neural net if we
            output layers 16, 12, 18, num_inputs would be 3.
          final_weight:
            The weight or probability given to the final (last) input when
            picking a single input at random or when building continuous
            weights. Must be strictly between 0 and 1.
          pure_prob:
            The per-frame probability of emitting exactly one input instead
            of an interpolation of all of them. Must be in [0, 1].
          stddev:
            Scale applied to the standard-normal noise that forms the
            randomized log-weights.

        Conceptually the weights of each frame are chosen like this::
            With probability `pure_prob`::
               With probability `final_weight`: choose final layer,
               Else: choose random non-final layer.
            Else::
               Start from log-weights that give weight `final_weight` to
               the final layer and equal weights to the other layers,
               perturb them with Gaussian noise scaled by `stddev`, and
               normalize with a softmax (note: the average weight of the
               final layer will then not be `final_weight` if stddev>0).
        """
        super().__init__()
        assert 0 <= pure_prob <= 1, pure_prob
        assert 0 < final_weight < 1, final_weight
        assert num_inputs >= 1, num_inputs

        self.num_inputs = num_inputs
        self.final_weight = final_weight
        self.pure_prob = pure_prob
        self.stddev = stddev

        # Log-weight of the final input relative to a log-weight of zero
        # for every other input.
        final_ratio = (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
        self.final_log_weight = torch.tensor(final_ratio).log().item()

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Combine the inputs.

        Args:
          inputs:
            A list of Tensor, e.g. from various layers of a transformer.
            All must be the same shape, of (*, num_channels)
        Returns:
          A Tensor of shape (*, num_channels). Outside of training (or
          when scripting, or with a single input) this is the last input.
        """
        num_inputs = self.num_inputs
        assert len(inputs) == num_inputs, f"{len(inputs)}, {num_inputs}"
        if not self.training or torch.jit.is_scripting() or len(inputs) == 1:
            return inputs[-1]

        first = inputs[0]
        num_channels = first.shape[-1]
        # Every leading dimension is flattened into one "frame" axis.
        num_frames = first.numel() // num_channels

        # stacked_inputs: (num_frames, num_channels, num_inputs)
        stacked_inputs = torch.stack(inputs, dim=first.ndim).reshape(
            num_frames, num_channels, num_inputs
        )

        # frame_weights: (num_frames, num_inputs, 1)
        frame_weights = self._get_random_weights(
            first.dtype, first.device, num_frames
        ).reshape(num_frames, num_inputs, 1)

        # combined: (num_frames, num_channels, 1)
        combined = torch.matmul(stacked_inputs, frame_weights)

        # Back to the shape of the inputs, (*, num_channels).
        return combined.reshape(first.shape)

    def _get_random_weights(
        self, dtype: torch.dtype, device: torch.device, num_frames: int
    ) -> torch.Tensor:
        """Draw the per-frame combination weights.

        Args:
          dtype:
            The data-type desired for the answer, e.g. float, double.
          device:
            The device needed for the answer.
          num_frames:
            The number of sets of weights desired
        Returns:
          A tensor of shape (num_frames, self.num_inputs), such that
          `ans.sum(dim=1)` is all ones.
        """
        if self.pure_prob == 0.0:
            return self._get_random_mixed_weights(dtype, device, num_frames)
        if self.pure_prob == 1.0:
            return self._get_random_pure_weights(dtype, device, num_frames)

        # Draw both kinds of weights, then choose between them per frame.
        pure = self._get_random_pure_weights(dtype, device, num_frames)
        mixed = self._get_random_mixed_weights(dtype, device, num_frames)
        use_pure = torch.rand(num_frames, 1, device=device) < self.pure_prob
        return torch.where(use_pure, pure, mixed)

    def _get_random_pure_weights(
        self, dtype: torch.dtype, device: torch.device, num_frames: int
    ) -> torch.Tensor:
        """Draw one-hot weights that select a single input per frame.

        Args:
          dtype:
            The data-type desired for the answer, e.g. float, double.
          device:
            The device needed for the answer.
          num_frames:
            The number of sets of weights desired.
        Returns:
          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
          exactly one weight equal to 1.0 on each frame.
        """
        last_index = self.num_inputs - 1

        # Index of the final input, repeated for every frame.
        final = torch.full((num_frames,), last_index, device=device)
        # Random indexes in [0..num_inputs - 2], i.e. non-final inputs only.
        nonfinal = torch.randint(last_index, (num_frames,), device=device)

        pick_final = torch.rand(num_frames, device=device) < self.final_weight
        indexes = torch.where(pick_final, final, nonfinal)

        one_hot = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs)
        return one_hot.to(dtype=dtype)

    def _get_random_mixed_weights(
        self, dtype: torch.dtype, device: torch.device, num_frames: int
    ) -> torch.Tensor:
        """Draw continuous (interpolating) weights for every frame.

        Args:
          dtype:
            The data-type desired for the answer, e.g. float, double.
          device:
            The device needed for the answer.
          num_frames:
            The number of sets of weights desired.
        Returns:
          A tensor of shape (num_frames, self.num_inputs), with elements
          in [0..1] that sum to one over the second axis, i.e.
          `ans.sum(dim=1)` is all ones.
        """
        noise = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
        logprobs = noise * self.stddev
        # Only the final input gets a non-zero offset on its log-weight.
        logprobs[:, -1] += self.final_log_weight
        return logprobs.softmax(dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_random_combine(
    final_weight: float,
    pure_prob: float,
    stddev: float,
) -> None:
    """Smoke-test RandomCombine for one configuration.

    Feeds three all-ones tensors through a freshly built module and checks
    that the output keeps the input shape and is (close to) all ones.

    Args:
      final_weight:
        Passed through to RandomCombine.
      pure_prob:
        Passed through to RandomCombine.
      stddev:
        Passed through to RandomCombine.
    """
    header = (
        f"_test_random_combine: final_weight={final_weight}, "
        f"pure_prob={pure_prob}, stddev={stddev}"
    )
    print(header)

    num_inputs = 3
    num_channels = 50
    combiner = RandomCombine(
        num_inputs=num_inputs,
        final_weight=final_weight,
        pure_prob=pure_prob,
        stddev=stddev,
    )

    inputs = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
    output = combiner(inputs)

    reference = inputs[0]
    assert output.shape == reference.shape
    # All inputs are ones and the weights sum to one, so the output is ones.
    assert torch.allclose(output, reference)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_random_combine_main() -> None:
    """Run `_test_random_combine` over a fixed list of configurations."""
    # Each entry is (final_weight, pure_prob, stddev); cases run in this order.
    cases = (
        (0.999, 0, 0.0),
        (0.5, 0, 0.0),
        (0.999, 0, 0.0),
        (0.5, 0, 0.3),
        (0.5, 1, 0.3),
        (0.5, 0.5, 0.3),
    )
    for final_weight, pure_prob, stddev in cases:
        _test_random_combine(final_weight, pure_prob, stddev)
|
||||||
|
|
||||||
|
|
||||||
|
# Run the RandomCombine self-tests when this file is executed as a script.
if __name__ == "__main__":
    _test_random_combine_main()
|
1033
egs/tedlium3/ASR/conformer_ctc2/conformer.py
Normal file
1033
egs/tedlium3/ASR/conformer_ctc2/conformer.py
Normal file
File diff suppressed because it is too large
Load Diff
899
egs/tedlium3/ASR/conformer_ctc2/decode.py
Executable file
899
egs/tedlium3/ASR/conformer_ctc2/decode.py
Executable file
@ -0,0 +1,899 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
|
||||||
|
# Fangjun Kuang,
|
||||||
|
# Quandong Wang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import TedLiumAsrDataModule
|
||||||
|
from conformer import Conformer
|
||||||
|
from train import add_model_arguments
|
||||||
|
|
||||||
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
nbest_decoding,
|
||||||
|
nbest_oracle,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_attention_decoder,
|
||||||
|
rescore_with_n_best_list,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
|
from icefall.env import get_env_info
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
get_texts,
|
||||||
|
load_averaged_model,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for decoding.

    Besides the options defined here, the model options registered by
    `add_model_arguments` are added to the parser.

    Returns:
      The argument parser. Defaults are shown in `--help` because
      `argparse.ArgumentDefaultsHelpFormatter` is used.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="attention-decoder",
        help="""Decoding method.
        Supported values are:
            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
              It needs neither a lexicon nor an n-gram LM.
            - (1) ctc-greedy-search. It only use CTC output and a sentence piece
              model for decoding. It produces the same results with ctc-decoding.
            - (2) 1best. Extract the best path from the decoding lattice as the
              decoding result.
            - (3) nbest. Extract n paths from the decoding lattice; the path
              with the highest score is the decoding result.
            - (4) nbest-rescoring. Extract n paths from the decoding lattice,
              rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
              the highest score is the decoding result.
            - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
              n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
              is the decoding result.
            - (6) attention-decoder. Extract n paths from the LM rescored
              lattice, the path with the highest score is the decoding result.
            - (7) nbest-oracle. Its WER is the lower bound of any n-best
              rescoring method can achieve. Useful for debugging n-best
              rescoring method.
        """,
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        # NOTE: adjacent literals are concatenated verbatim, so every piece
        # must end with a space to keep the sentences separated.
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for n-best based decoding method.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        """,
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kinds of n-best based rescoring.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, attention-decoder, and nbest-oracle
        A smaller value results in more unique paths.
        """,
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="conformer_ctc2/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--lang-dir",
        type=str,
        default="data/lang_bpe_500",
        help="The lang dir",
    )

    parser.add_argument(
        "--lm-path",
        type=str,
        default="data/lm/G_4_gram.pt",
        # The value is used as a file path (its suffix and existence are
        # checked by the caller), not as a directory.
        help="""Path to the n-gram LM used for rescoring.
        It should point to either lm_fname.pt or lm_fname.fst.txt
        """,
    )

    parser.add_argument(
        "--result-dir",
        type=str,
        default="conformer_ctc2/exp",
        help="Directory to store results.",
    )

    add_model_arguments(parser)

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
    """Return the fixed (non command-line) parameters used for decoding.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - feature_dim: The model input dim. It has to match the one used
                       in computing features.

        - subsampling_factor:  The subsampling factor for the model.

        - search_beam, output_beam, min_active_states, max_active_states,
          use_double_scores: Options forwarded to the lattice-based decoding
          functions.

        - env_info: The value returned by `get_env_info()`.
    """
    # parameters for conformer
    conformer_options = {
        "subsampling_factor": 4,
        "feature_dim": 80,
    }
    # parameters for decoding
    decoding_options = {
        "search_beam": 15,
        "output_beam": 8,
        "min_active_states": 10,
        "max_active_states": 7000,
        "use_double_scores": True,
    }
    return AttributeDict(
        {
            **conformer_options,
            **decoding_options,
            "env_info": get_env_info(),
        }
    )
|
||||||
|
|
||||||
|
|
||||||
|
def ctc_greedy_search(
    ctc_probs: torch.Tensor,
    mask: torch.Tensor,
) -> List[List[int]]:
    """Apply CTC greedy search.

    For every frame the highest-scoring token is taken, masked frames are
    forced to token 0, consecutive repeats are merged and token 0 is dropped.

    Args:
      ctc_probs (torch.Tensor): (batch, max_len, num_bpe)
      mask (torch.Tensor): (batch, max_len); True marks frames to ignore.
    Returns:
      One list of token IDs per utterance (the best path result).
    """
    best_tokens = ctc_probs.max(dim=2).indices  # (B, maxlen)
    best_tokens.masked_fill_(mask, 0)  # (B, maxlen)

    return [
        [token for token in torch.unique_consecutive(frames).tolist() if token > 0]
        for frames in best_tokens
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def _word_ids_to_words(
    best_path: k2.Fsa, word_table: k2.SymbolTable
) -> List[List[str]]:
    """Map the word IDs on `best_path` to words, dropping "<unk>"."""
    word_ids = get_texts(best_path)
    return [
        [word_table[i] for i in ids if word_table[i] != "<unk>"] for ids in word_ids
    ]


def _token_ids_to_words(
    token_ids: List[List[int]], bpe_model: spm.SentencePieceProcessor
) -> List[List[str]]:
    """Decode BPE token IDs to words, dropping the BPE model's unk symbol."""
    # texts is a list of str, e.g., ['xxx yyy zzz', ...]
    texts = bpe_model.decode(token_ids)
    unk = bpe_model.decode(bpe_model.unk_id()).strip()
    # The result is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
    return [[w for w in s.split() if w != unk] for s in texts]


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    batch: dict,
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
) -> Optional[Dict[str, List[List[str]]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               "ctc-decoding" or "ctc-greedy-search" for the CTC-only
               methods; for the LM-scale based methods the keys are the
               ones returned by the rescoring function, e.g.
               `lm_scale_0.7`.
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.

        - params.method is "ctc-decoding", it uses the best path of the
          lattice built from the CTC topology H.
        - params.method is "ctc-greedy-search", it uses
          :func:`ctc_greedy_search`; no lattice is built.
        - params.method is "nbest-oracle", it uses :func:`nbest_oracle`.
        - params.method is "1best", it uses 1best decoding without LM rescoring.
        - params.method is "nbest", it uses nbest decoding without LM rescoring.
        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
          rescoring.
        - params.method is "attention-decoder", it rescores n-best paths
          with the attention decoder of `model`.

      model:
        The neural model.
      HLG:
        The decoding graph. Used only when params.method is NOT ctc-decoding
        or ctc-greedy-search.
      H:
        The ctc topo. Used only when params.method is ctc-decoding
        or ctc-greedy-search.
      bpe_model:
        The BPE model. Used only when params.method is ctc-decoding
        or ctc-greedy-search.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      sos_id:
        The token ID of the SOS.
      eos_id:
        The token ID of the EOS.
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict. Note: If it decodes to nothing, then return None.
    """
    device = HLG.device if HLG is not None else H.device
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]

    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
    # nnet_output is (N, T, C)

    if params.method == "ctc-greedy-search":
        # Greedy search works directly on the network output, so return
        # before the (expensive and here unused) lattice is generated.
        assert bpe_model is not None
        token_ids = ctc_greedy_search(nnet_output, memory_key_padding_mask)
        return {"ctc-greedy-search": _token_ids_to_words(token_ids, bpe_model)}

    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            torch.div(
                supervisions["start_frame"],
                params.subsampling_factor,
                rounding_mode="floor",
            ),
            torch.div(
                supervisions["num_frames"],
                params.subsampling_factor,
                rounding_mode="floor",
            ),
        ),
        1,
    ).to(torch.int32)

    if H is None:
        assert HLG is not None
        decoding_graph = HLG
    else:
        assert HLG is None
        assert bpe_model is not None
        decoding_graph = H

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    if params.method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # Note: `best_path.aux_labels` contains token IDs, not word IDs
        # since we are using H, not HLG here.
        #
        # token_ids is a list-of-list of IDs
        token_ids = get_texts(best_path)
        return {"ctc-decoding": _token_ids_to_words(token_ids, bpe_model)}

    if params.method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
        # as HLG decoding is faster and the oracle WER
        # is only slightly worse than that of rescored lattices.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=supervisions["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            oov="<unk>",
        )
        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
        return {key: _word_ids_to_words(best_path, word_table)}

    if params.method == "nbest":
        best_path = nbest_decoding(
            lattice=lattice,
            num_paths=params.num_paths,
            use_double_scores=params.use_double_scores,
            nbest_scale=params.nbest_scale,
        )
        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
        return {key: _word_ids_to_words(best_path, word_table)}

    assert params.method in [
        "1best",
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
    ]

    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

    if params.method == "1best":
        best_path_dict = one_best_decoding(
            lattice=lattice,
            lm_scale_list=lm_scale_list,
        )
    elif params.method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif params.method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    elif params.method == "attention-decoder":
        best_path_dict = rescore_with_attention_decoder(
            lattice=lattice,
            num_paths=params.num_paths,
            model=model,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            sos_id=sos_id,
            eos_id=eos_id,
            nbest_scale=params.nbest_scale,
        )
    else:
        raise ValueError(f"Unsupported decoding method: {params.method}")

    if best_path_dict is None:
        # Nothing could be decoded for this batch.
        return None

    return {
        lm_scale_str: _word_ids_to_words(best_path, word_table)
        for lm_scale_str, best_path in best_path_dict.items()
    }
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      HLG:
        The decoding graph. Used only when params.method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.method is ctc-decoding.
      word_table:
        It is the word symbol table.
      sos_id:
        The token ID for SOS.
      eos_id:
        The token ID for EOS.
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return a dict whose keys are the keys returned by
      :func:`decode_one_batch` (e.g. "lm_scale_0.7" if LM rescoring is used).
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference words, and the predicted words.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # The dataloader has no length.
        num_batches = "?"

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            batch=batch,
            word_table=word_table,
            G=G,
            sos_id=sos_id,
            eos_id=eos_id,
        )

        if hyps_dict is not None:
            for lm_scale, hyps in hyps_dict.items():
                assert len(hyps) == len(texts)
                results[lm_scale].extend(
                    (cut_id, ref_text.split(), hyp_words)
                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
                )
        else:
            assert len(results) > 0, "It should not decode to empty in the first batch!"
            # Nothing was decoded: record an empty hypothesis for every cut.
            # The entries must have the same (cut_id, ref, hyp) layout as the
            # ones above because all of them are sorted and scored together.
            this_batch = [
                (cut_id, ref_text.split(), [])
                for cut_id, ref_text in zip(cut_ids, texts)
            ]

            for lm_scale in results:
                results[lm_scale].extend(this_batch)

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
) -> None:
    """Write transcripts, error statistics and a WER summary for one test set.

    Args:
      params:
        Decoding parameters; `params.result_dir` and `params.method` are read.
      test_set_name:
        Name of the test set; used in file names and log messages.
      results_dict:
        Maps a decoding setting to its list of result tuples, as returned
        by :func:`decode_dataset`.
    """
    # Disabled for attention-decoder since there are too many logs.
    enable_log = params.method != "attention-decoder"

    wer_by_setting = {}
    for key, results in results_dict.items():
        recog_path = params.result_dir / f"recogs-{test_set_name}-{key}.txt"
        ordered = sorted(results)
        store_transcripts(filename=recog_path, texts=ordered)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.result_dir / f"errs-{test_set_name}-{key}.txt"
        with open(errs_filename, "w") as f:
            wer_by_setting[key] = write_error_stats(
                f, f"{test_set_name}-{key}", ordered, enable_log=enable_log
            )

        if enable_log:
            logging.info(f"Wrote detailed error stats to {errs_filename}")

    # Settings ordered from the lowest to the highest WER.
    ranked = sorted(wer_by_setting.items(), key=lambda item: item[1])

    errs_info = params.result_dir / f"wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in ranked:
            print(f"{key}\t{val}", file=f)

    summary = [f"\nFor {test_set_name}, WER of different settings are:\n"]
    for rank, (key, val) in enumerate(ranked):
        # Only the first (lowest-WER) line is annotated.
        note = f"\tbest for {test_set_name}" if rank == 0 else ""
        summary.append(f"{key}\t{val}{note}\n")
    logging.info("".join(summary))
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
def main() -> None:
    """Decode the TED-LIUM dev and test sets and write the results.

    Parses the command line, builds the decoding graph (H or HLG) and,
    if requested, the rescoring LM, loads the (averaged) model and then
    decodes both sets, saving transcripts and WER statistics to
    `--result-dir`.
    """
    parser = get_parser()
    TedLiumAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
    args.lm_path = Path(args.lm_path)
    args.result_dir = Path(args.result_dir)

    # Start from a clean result dir, but never delete the experiment dir
    # (or one of its parents): it holds the checkpoints that are loaded
    # below, and --result-dir and --exp-dir have the same default value.
    result_dir = args.result_dir.resolve()
    exp_dir = args.exp_dir.resolve()
    holds_checkpoints = result_dir == exp_dir or result_dir in exp_dir.parents
    if args.result_dir.is_dir() and not holds_checkpoints:
        shutil.rmtree(args.result_dir)
    args.result_dir.mkdir(parents=True, exist_ok=True)

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    graph_compiler = BpeCtcTrainingGraphCompiler(
        params.lang_dir,
        device=device,
        sos_token="<sos/eos>",
        eos_token="<sos/eos>",
    )
    sos_id = graph_compiler.sos_id
    eos_id = graph_compiler.eos_id

    if params.method in ("ctc-decoding", "ctc-greedy-search"):
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(str(params.lang_dir / "bpe.model"))
    else:
        H = None
        bpe_model = None
        HLG = k2.Fsa.from_dict(
            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
        )
        assert HLG.requires_grad is False

        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    if params.method in ("nbest-rescoring", "whole-lattice-rescoring"):
        assert params.lm_path.suffix in (".pt", ".txt")

        if params.lm_path.is_file() and params.lm_path.suffix == ".pt":
            logging.info(f"Loading pre-compiled {params.lm_path.name}")
            d = torch.load(params.lm_path, map_location=device)
            G = k2.Fsa.from_dict(d)
        elif not params.lm_path.is_file() and params.lm_path.suffix == ".txt":
            raise FileNotFoundError(f"No such language model file: '{params.lm_path}'")
        else:
            # here we pass only if LM filename ends with '.pt' and doesn't exist
            # or if LM filename ends '.txt' and exists.
            if (
                not params.lm_path.is_file()
                and params.lm_path.suffix == ".pt"
                and not (
                    params.lm_path.parent / f"{params.lm_path.stem}.fst.txt"
                ).is_file()
            ):
                raise FileNotFoundError(
                    f"No such language model file: '{params.lm_path}'\n"
                    "'.fst.txt' representation of the language model was "
                    "not found either."
                )
            else:
                # whatever params.lm_path.name we got lm_name.pt or lm_name.fst.txt
                # we are going to load lm_name.fst.txt here
                params.lm_path = params.lm_path.parent / params.lm_path.name.replace(
                    ".pt", ".fst.txt"
                )
                logging.info(f"Loading {params.lm_path.name}")
                logging.warning("It may take 8 minutes.")
                with open(params.lm_path) as f:
                    first_word_disambig_id = lexicon.word_table["#0"]

                    G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                    # G.aux_labels is not needed in later computations, so
                    # remove it here.
                    del G.aux_labels
                    # CAUTION: The following line is crucial.
                    # Arcs entering the back-off state have label equal to #0.
                    # We have to change it to 0 here.
                    G.labels[G.labels >= first_word_disambig_id] = 0
                    # See https://github.com/k2-fsa/k2/issues/874
                    # for why we need to set G.properties to None
                    G.__dict__["_properties"] = None
                    G = k2.Fsa.from_fsas([G]).to(device)
                    G = k2.arc_sort(G)
                    # Save a dummy value so that it can be loaded in C++.
                    # See https://github.com/pytorch/pytorch/issues/67902
                    # for why we need to do this.
                    G.dummy = 1

                    # Cache the compiled LM next to the text file so that
                    # it can be loaded directly from the .pt file next time.
                    torch.save(
                        G.as_dict(),
                        params.lm_path.parent
                        / params.lm_path.name.replace(".fst.txt", ".pt"),
                    )

        if params.method == "whole-lattice-rescoring":
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    model = Conformer(
        num_features=params.feature_dim,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        d_model=params.dim_model,
        nhead=params.nhead,
        dim_feedforward=params.dim_feedforward,
        num_encoder_layers=params.num_encoder_layers,
        num_decoder_layers=params.num_decoder_layers,
    )

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            # One extra checkpoint is needed: the oldest one only marks the
            # (excluded) start of the averaging range.
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    tedlium = TedLiumAsrDataModule(args)

    valid_cuts = tedlium.dev_cuts()
    test_cuts = tedlium.test_cuts()

    valid_dl = tedlium.valid_dataloaders(valid_cuts)
    test_dl = tedlium.test_dataloaders(test_cuts)

    test_sets = ["dev", "test"]
    test_dls = [valid_dl, test_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            word_table=lexicon.word_table,
            G=G,
            sos_id=sos_id,
            eos_id=eos_id,
        )

        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

    logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
# Importing `add_model_arguments` from train.py already runs
# torch.set_num_interop_threads(1) there. Setting the number of interop
# threads a second time (once in train.py and once in decode.py) causes an
# error, which is why the call below is guarded by an `if` statement.
if torch.get_num_interop_threads() != 1:
    torch.set_num_interop_threads(1)

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# Run decoding when this file is executed as a script.
if __name__ == "__main__":
    main()
|
294
egs/tedlium3/ASR/conformer_ctc2/export.py
Executable file
294
egs/tedlium3/ASR/conformer_ctc2/export.py
Executable file
@ -0,0 +1,294 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2022 Behavox LLC (Author: Daniil Kulko)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# This script converts several saved checkpoints
|
||||||
|
# to a single one using model averaging.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
./conformer_ctc2/export.py \
|
||||||
|
--exp-dir ./conformer_ctc2/exp \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10
|
||||||
|
|
||||||
|
It will generate a file exp_dir/pretrained.pt
|
||||||
|
|
||||||
|
To use the generated file with `conformer_ctc2/decode.py`,
|
||||||
|
you can do:
|
||||||
|
|
||||||
|
cd /path/to/exp_dir
|
||||||
|
ln -s pretrained.pt epoch-9999.pt
|
||||||
|
|
||||||
|
cd /path/to/egs/tedlium3/ASR
|
||||||
|
./conformer_ctc2/decode.py \
|
||||||
|
--exp-dir ./conformer_ctc2/exp \
|
||||||
|
--epoch 9999 \
|
||||||
|
--avg 1 \
|
||||||
|
--max-duration 100
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from conformer import Conformer
|
||||||
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
|
from train import add_model_arguments
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import AttributeDict, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser() -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
|
Note: Epoch counts from 0.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help=(
|
||||||
|
"Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help=(
|
||||||
|
"Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. "
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="conformer_ctc2/exp",
|
||||||
|
help="""It specifies the directory where all training related
|
||||||
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
default="data/lang_bpe_500",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--jit",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="""True to save a model after applying torch.jit.script.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
"""Return a dict containing training parameters.
|
||||||
|
|
||||||
|
All training related parameters that are not passed from the commandline
|
||||||
|
are saved in the variable `params`.
|
||||||
|
|
||||||
|
Commandline options are merged into `params` after they are parsed, so
|
||||||
|
you can also access them via `params`.
|
||||||
|
|
||||||
|
Explanation of options saved in `params`:
|
||||||
|
|
||||||
|
- feature_dim: The model input dim. It has to match the one used
|
||||||
|
in computing features.
|
||||||
|
|
||||||
|
- subsampling_factor: The subsampling factor for the model.
|
||||||
|
"""
|
||||||
|
# parameters for conformer
|
||||||
|
params = AttributeDict({"subsampling_factor": 4, "feature_dim": 80})
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_parser().parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
args.lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
num_classes = max_token_id + 1 # +1 for the blank
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
|
||||||
|
model = Conformer(
|
||||||
|
num_features=params.feature_dim,
|
||||||
|
num_classes=num_classes,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
d_model=params.dim_model,
|
||||||
|
nhead=params.nhead,
|
||||||
|
dim_feedforward=params.dim_feedforward,
|
||||||
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to("cpu")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if params.jit:
|
||||||
|
convert_scaled_to_non_scaled(model, inplace=True)
|
||||||
|
logging.info("Using torch.jit.script")
|
||||||
|
model = torch.jit.script(model)
|
||||||
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
model.save(str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
else:
|
||||||
|
logging.info("Not using torch.jit.script")
|
||||||
|
# Save it using a format so that it can be loaded
|
||||||
|
# by :func:`load_checkpoint`
|
||||||
|
filename = params.exp_dir / "pretrained.pt"
|
||||||
|
torch.save({"model": model.state_dict()}, str(filename))
|
||||||
|
logging.info(f"Saved to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
main()
|
1
egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
Symbolic link
1
egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/conformer_ctc/label_smoothing.py
|
1
egs/tedlium3/ASR/conformer_ctc2/lstmp.py
Symbolic link
1
egs/tedlium3/ASR/conformer_ctc2/lstmp.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py
|
1
egs/tedlium3/ASR/conformer_ctc2/optim.py
Symbolic link
1
egs/tedlium3/ASR/conformer_ctc2/optim.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless2/optim.py
|
1
egs/tedlium3/ASR/conformer_ctc2/scaling.py
Symbolic link
1
egs/tedlium3/ASR/conformer_ctc2/scaling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py
|
1
egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
Symbolic link
1
egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
|
1
egs/tedlium3/ASR/conformer_ctc2/subsampling.py
Symbolic link
1
egs/tedlium3/ASR/conformer_ctc2/subsampling.py
Symbolic link
@ -0,0 +1 @@
|
|||||||
|
../../../librispeech/ASR/conformer_ctc2/subsampling.py
|
1061
egs/tedlium3/ASR/conformer_ctc2/train.py
Executable file
1061
egs/tedlium3/ASR/conformer_ctc2/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1093
egs/tedlium3/ASR/conformer_ctc2/transformer.py
Normal file
1093
egs/tedlium3/ASR/conformer_ctc2/transformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -4,16 +4,18 @@
|
|||||||
"""
|
"""
|
||||||
Convert a transcript based on words to a list of BPE ids.
|
Convert a transcript based on words to a list of BPE ids.
|
||||||
|
|
||||||
For example, if we use 2 as the encoding id of <unk>:
|
For example, if we use 2 as the encoding id of <unk>
|
||||||
|
Note: it inserts a space token before each <unk>
|
||||||
|
|
||||||
texts = ['this is a <unk> day']
|
texts = ['this is a <unk> day']
|
||||||
spm_ids = [[38, 33, 6, 2, 316]]
|
spm_ids = [[38, 33, 6, 15, 2, 316]]
|
||||||
|
|
||||||
texts = ['<unk> this is a sunny day']
|
texts = ['<unk> this is a sunny day']
|
||||||
spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316]]
|
spm_ids = [[15, 2, 38, 33, 6, 118, 11, 11, 21, 316]]
|
||||||
|
|
||||||
texts = ['<unk>']
|
texts = ['<unk>']
|
||||||
spm_ids = [[2]]
|
spm_ids = [[15, 2]]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -38,29 +40,27 @@ def get_args():
|
|||||||
|
|
||||||
def convert_texts_into_ids(
|
def convert_texts_into_ids(
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
unk_id: int,
|
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
) -> List[List[int]]:
|
) -> List[List[int]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
texts:
|
texts:
|
||||||
A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
|
A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
|
||||||
unk_id:
|
sp:
|
||||||
A number id for the token '<unk>'.
|
A sentencepiece BPE model.
|
||||||
Returns:
|
Returns:
|
||||||
Return an integer list of bpe ids.
|
Return an integer list of bpe ids.
|
||||||
"""
|
"""
|
||||||
y = []
|
y = []
|
||||||
for text in texts:
|
for text in texts:
|
||||||
y_ids = []
|
|
||||||
if "<unk>" in text:
|
if "<unk>" in text:
|
||||||
text_segments = text.split("<unk>")
|
id_segments = sp.encode(text.split("<unk>"), out_type=int)
|
||||||
id_segments = sp.encode(text_segments, out_type=int)
|
|
||||||
|
y_ids = []
|
||||||
for i in range(len(id_segments)):
|
for i in range(len(id_segments)):
|
||||||
if i != len(id_segments) - 1:
|
y_ids += id_segments[i]
|
||||||
y_ids.extend(id_segments[i] + [unk_id])
|
if i < len(id_segments) - 1:
|
||||||
else:
|
y_ids += [sp.piece_to_id("▁"), sp.unk_id()]
|
||||||
y_ids.extend(id_segments[i])
|
|
||||||
else:
|
else:
|
||||||
y_ids = sp.encode(text, out_type=int)
|
y_ids = sp.encode(text, out_type=int)
|
||||||
y.append(y_ids)
|
y.append(y_ids)
|
||||||
@ -70,19 +70,13 @@ def convert_texts_into_ids(
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = get_args()
|
args = get_args()
|
||||||
texts = args.texts
|
|
||||||
bpe_model = args.bpe_model
|
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
sp = spm.SentencePieceProcessor()
|
||||||
sp.load(bpe_model)
|
sp.load(args.bpe_model)
|
||||||
unk_id = sp.piece_to_id("<unk>")
|
|
||||||
|
|
||||||
y = convert_texts_into_ids(
|
y = convert_texts_into_ids(texts=args.texts, sp=sp)
|
||||||
texts=texts,
|
|
||||||
unk_id=unk_id,
|
logging.info(f"The input texts: {args.texts}")
|
||||||
sp=sp,
|
|
||||||
)
|
|
||||||
logging.info(f"The input texts: {texts}")
|
|
||||||
logging.info(f"The encoding ids: {y}")
|
logging.info(f"The encoding ids: {y}")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1 +0,0 @@
|
|||||||
../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
|
|
@ -1 +0,0 @@
|
|||||||
../../../librispeech/ASR/local/generate_unique_lexicon.py
|
|
@ -1 +0,0 @@
|
|||||||
../../../librispeech/ASR/local/prepare_lang.py
|
|
@ -1,94 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Copyright 2022 Xiaomi Corp. (authors: Mingshuang Luo)
|
|
||||||
#
|
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
This script takes as input supervisions json dir "data/manifests"
|
|
||||||
consisting of supervisions_train.json and does the following:
|
|
||||||
|
|
||||||
1. Generate lexicon_words.txt.
|
|
||||||
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import lhotse
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--manifests-dir",
|
|
||||||
type=str,
|
|
||||||
help="""Input directory.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--lang-dir",
|
|
||||||
type=str,
|
|
||||||
help="""Output directory.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_lexicon(manifests_dir: str, lang_dir: str):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
manifests_dir:
|
|
||||||
The manifests directory, e.g., data/manifests.
|
|
||||||
lang_dir:
|
|
||||||
The language directory, e.g., data/lang_phone.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
The lexicon_words.txt file.
|
|
||||||
"""
|
|
||||||
words = set()
|
|
||||||
|
|
||||||
lexicon = Path(lang_dir) / "lexicon_words.txt"
|
|
||||||
sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
|
|
||||||
for s in sups:
|
|
||||||
# list the words units and filter the empty item
|
|
||||||
words_list = list(filter(None, s.text.split()))
|
|
||||||
|
|
||||||
for word in words_list:
|
|
||||||
if word not in words and word != "<unk>":
|
|
||||||
words.add(word)
|
|
||||||
|
|
||||||
with open(lexicon, "w") as f:
|
|
||||||
for word in sorted(words):
|
|
||||||
f.write(word + " " + word)
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = get_args()
|
|
||||||
manifests_dir = Path(args.manifests_dir)
|
|
||||||
lang_dir = Path(args.lang_dir)
|
|
||||||
|
|
||||||
logging.info("Generating lexicon_words.txt")
|
|
||||||
prepare_lexicon(manifests_dir, lang_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
|
||||||
|
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
|
||||||
|
|
||||||
main()
|
|
@ -1,5 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo)
|
# Copyright 2021 Xiaomi Corp. (author: Mingshuang Luo)
|
||||||
|
# Copyright 2022 Behavox LLC. (author: Daniil Kulko)
|
||||||
#
|
#
|
||||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
#
|
#
|
||||||
@ -17,68 +18,67 @@
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This script takes as input supervisions json dir "data/manifests"
|
This script takes input text file and removes all words
|
||||||
consisting of supervisions_train.json and does the following:
|
that include any character outside the English alphabet.
|
||||||
|
|
||||||
1. Generate train.text.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import lhotse
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--manifests-dir",
|
"--input-text-path",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Input directory.
|
help="Input text file path.",
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--output-text-path",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Output directory.
|
help="Output text file path.",
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def prepare_transcripts(manifests_dir: str, lang_dir: str):
|
def prepare_transcripts(input_text_path: Path, output_text_path: Path) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
manifests_dir:
|
input_text_path:
|
||||||
The manifests directory, e.g., data/manifests.
|
The input data text file path, e.g., data/lang/train_orig.txt.
|
||||||
lang_dir:
|
output_text_path:
|
||||||
The language directory, e.g., data/lang_phone.
|
The output data text file path, e.g., data/lang/train.txt.
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
The train.text in lang_dir.
|
Saved text file in output_text_path.
|
||||||
"""
|
"""
|
||||||
texts = []
|
|
||||||
|
|
||||||
train_text = Path(lang_dir) / "train.text"
|
foreign_chr_check = re.compile(r"[^a-z']")
|
||||||
sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz")
|
|
||||||
for s in sups:
|
|
||||||
texts.append(s.text)
|
|
||||||
|
|
||||||
with open(train_text, "w") as f:
|
logging.info(f"Loading {input_text_path.name}")
|
||||||
for text in texts:
|
with open(input_text_path, "r", encoding="utf8") as f:
|
||||||
f.write(text)
|
texts = {t.rstrip("\n") for t in f}
|
||||||
f.write("\n")
|
|
||||||
|
texts = {
|
||||||
|
" ".join([w for w in t.split() if foreign_chr_check.search(w) is None])
|
||||||
|
for t in texts
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(output_text_path, "w+", encoding="utf8") as f:
|
||||||
|
for t in texts:
|
||||||
|
f.write(f"{t}\n")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
args = get_args()
|
args = get_args()
|
||||||
manifests_dir = Path(args.manifests_dir)
|
input_text_path = Path(args.input_text_path)
|
||||||
lang_dir = Path(args.lang_dir)
|
output_text_path = Path(args.output_text_path)
|
||||||
|
|
||||||
logging.info("Generating train.text")
|
logging.info(f"Generating {output_text_path.name}")
|
||||||
prepare_transcripts(manifests_dir, lang_dir)
|
prepare_transcripts(input_text_path, output_text_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
83
egs/tedlium3/ASR/local/prepare_words.py
Executable file
83
egs/tedlium3/ASR/local/prepare_words.py
Executable file
@ -0,0 +1,83 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Behavox LLC. (authors: Daniil Kulko)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script takes as input supervisions json dir "data/manifests"
|
||||||
|
consisting of tedlium_supervisions_train.json and does the following:
|
||||||
|
|
||||||
|
1. Generate words.txt.
|
||||||
|
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
help="Output directory.",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_words(lang_dir: str) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
lang_dir:
|
||||||
|
The language directory, e.g., data/lang.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
The words.txt file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
words_orig_path = Path(lang_dir) / "words_orig.txt"
|
||||||
|
words_path = Path(lang_dir) / "words.txt"
|
||||||
|
|
||||||
|
foreign_chr_check = re.compile(r"[^a-z']")
|
||||||
|
|
||||||
|
logging.info(f"Loading {words_orig_path.name}")
|
||||||
|
with open(words_orig_path, "r", encoding="utf8") as f:
|
||||||
|
words = {w for w_compl in f for w in w_compl.strip("-\n").split("_")}
|
||||||
|
words = {w for w in words if foreign_chr_check.search(w) is None and w != ""}
|
||||||
|
words.add("<unk>")
|
||||||
|
words = ["<eps>", "!SIL"] + sorted(words) + ["#0", "<s>", "</s>"]
|
||||||
|
|
||||||
|
with open(words_path, "w+", encoding="utf8") as f:
|
||||||
|
for idx, word in enumerate(words):
|
||||||
|
f.write(f"{word} {idx}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = get_args()
|
||||||
|
lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
|
logging.info("Generating words.txt")
|
||||||
|
prepare_words(lang_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
|
|
||||||
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
|
|
||||||
|
main()
|
@ -1 +0,0 @@
|
|||||||
../../../librispeech/ASR/local/test_prepare_lang.py
|
|
@ -5,7 +5,6 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
|||||||
|
|
||||||
set -eou pipefail
|
set -eou pipefail
|
||||||
|
|
||||||
nj=15
|
|
||||||
stage=0
|
stage=0
|
||||||
stop_stage=100
|
stop_stage=100
|
||||||
|
|
||||||
@ -63,6 +62,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
|||||||
mv $dl_dir/TEDLIUM_release-3 $dl_dir/tedlium3
|
mv $dl_dir/TEDLIUM_release-3 $dl_dir/tedlium3
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Download big and small 4-gram language models
|
||||||
|
if [ ! -d $dl_dir/lm ]; then
|
||||||
|
wget --continue http://kaldi-asr.org/models/5/4gram_small.arpa.gz -P $dl_dir/lm
|
||||||
|
wget --continue http://kaldi-asr.org/models/5/4gram_big.arpa.gz -P $dl_dir/lm
|
||||||
|
gzip -d $dl_dir/lm/4gram_small.arpa.gz $dl_dir/lm/4gram_big.arpa.gz
|
||||||
|
fi
|
||||||
|
|
||||||
# If you have pre-downloaded it to /path/to/musan,
|
# If you have pre-downloaded it to /path/to/musan,
|
||||||
# you can create a symlink
|
# you can create a symlink
|
||||||
#
|
#
|
||||||
@ -100,7 +106,14 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
|||||||
|
|
||||||
if [ ! -e data/fbank/.tedlium3.done ]; then
|
if [ ! -e data/fbank/.tedlium3.done ]; then
|
||||||
mkdir -p data/fbank
|
mkdir -p data/fbank
|
||||||
|
|
||||||
python3 ./local/compute_fbank_tedlium.py
|
python3 ./local/compute_fbank_tedlium.py
|
||||||
|
|
||||||
|
gunzip -c data/fbank/tedlium_cuts_train.jsonl.gz | shuf | \
|
||||||
|
gzip -c > data/fbank/tedlium_cuts_train-shuf.jsonl.gz
|
||||||
|
mv data/fbank/tedlium_cuts_train-shuf.jsonl.gz \
|
||||||
|
data/fbank/tedlium_cuts_train.jsonl.gz
|
||||||
|
|
||||||
touch data/fbank/.tedlium3.done
|
touch data/fbank/.tedlium3.done
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
@ -115,28 +128,24 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||||
log "Stage 5: Prepare phone based lang"
|
log "Stage 5: Prepare BPE train data and set of words"
|
||||||
lang_dir=data/lang_phone
|
lang_dir=data/lang
|
||||||
mkdir -p $lang_dir
|
mkdir -p $lang_dir
|
||||||
|
|
||||||
if [ ! -f $lang_dir/train.text ]; then
|
if [ ! -f $lang_dir/train.txt ]; then
|
||||||
|
gunzip -c $dl_dir/tedlium3/LM/*.en.gz | sed 's: <\/s>::g' > $lang_dir/train_orig.txt
|
||||||
|
|
||||||
./local/prepare_transcripts.py \
|
./local/prepare_transcripts.py \
|
||||||
--lang-dir $lang_dir \
|
--input-text-path $lang_dir/train_orig.txt \
|
||||||
--manifests-dir data/manifests
|
--output-text-path $lang_dir/train.txt
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [ ! -f $lang_dir/lexicon_words.txt ]; then
|
if [ ! -f $lang_dir/words.txt ]; then
|
||||||
./local/prepare_lexicon.py \
|
|
||||||
--lang-dir $lang_dir \
|
|
||||||
--manifests-dir data/manifests
|
|
||||||
fi
|
|
||||||
|
|
||||||
(echo '!SIL SIL'; echo '<UNK> <UNK>'; ) |
|
awk '{print $1}' $dl_dir/tedlium3/TEDLIUM.152k.dic |
|
||||||
cat - $lang_dir/lexicon_words.txt |
|
sed 's:([0-9])::g' | sort | uniq > $lang_dir/words_orig.txt
|
||||||
sort | uniq > $lang_dir/lexicon.txt
|
|
||||||
|
|
||||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
./local/prepare_words.py --lang-dir $lang_dir
|
||||||
./local/prepare_lang.py --lang-dir $lang_dir
|
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -148,25 +157,56 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
|
|||||||
mkdir -p $lang_dir
|
mkdir -p $lang_dir
|
||||||
# We reuse words.txt from phone based lexicon
|
# We reuse words.txt from phone based lexicon
|
||||||
# so that the two can share G.pt later.
|
# so that the two can share G.pt later.
|
||||||
cp data/lang_phone/words.txt $lang_dir
|
cp data/lang/words.txt $lang_dir
|
||||||
|
|
||||||
if [ ! -f $lang_dir/transcript_words.txt ]; then
|
|
||||||
log "Generate data for BPE training"
|
|
||||||
cat data/lang_phone/train.text |
|
|
||||||
cut -d " " -f 2- > $lang_dir/transcript_words.txt
|
|
||||||
# remove the <unk> for transcript_words.txt
|
|
||||||
sed -i 's/ <unk>//g' $lang_dir/transcript_words.txt
|
|
||||||
sed -i 's/<unk> //g' $lang_dir/transcript_words.txt
|
|
||||||
sed -i 's/<unk>//g' $lang_dir/transcript_words.txt
|
|
||||||
fi
|
|
||||||
|
|
||||||
./local/train_bpe_model.py \
|
./local/train_bpe_model.py \
|
||||||
--lang-dir $lang_dir \
|
--lang-dir $lang_dir \
|
||||||
--vocab-size $vocab_size \
|
--vocab-size $vocab_size \
|
||||||
--transcript $lang_dir/transcript_words.txt
|
--transcript data/lang/train.txt
|
||||||
|
|
||||||
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
if [ ! -f $lang_dir/L_disambig.pt ]; then
|
||||||
./local/prepare_lang_bpe.py --lang-dir $lang_dir
|
./local/prepare_lang_bpe.py --lang-dir $lang_dir --oov "<unk>"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
|
||||||
|
log "Stage 7: Prepare G"
|
||||||
|
# We assume you have installed kaldilm; if not, please install
|
||||||
|
# it using: pip install kaldilm
|
||||||
|
|
||||||
|
mkdir -p data/lm
|
||||||
|
if [ ! -f data/lm/G_4_gram_small.fst.txt ]; then
|
||||||
|
# It is used in building HLG
|
||||||
|
python3 -m kaldilm \
|
||||||
|
--read-symbol-table="data/lang/words.txt" \
|
||||||
|
--disambig-symbol='#0' \
|
||||||
|
--max-order=4 \
|
||||||
|
--max-arpa-warnings=-1 \
|
||||||
|
$dl_dir/lm/4gram_small.arpa > data/lm/G_4_gram_small.fst.txt
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -f data/lm/G_4_gram_big.fst.txt ]; then
|
||||||
|
# It is used for LM rescoring
|
||||||
|
python3 -m kaldilm \
|
||||||
|
--read-symbol-table="data/lang/words.txt" \
|
||||||
|
--disambig-symbol='#0' \
|
||||||
|
--max-order=4 \
|
||||||
|
--max-arpa-warnings=-1 \
|
||||||
|
$dl_dir/lm/4gram_big.arpa > data/lm/G_4_gram_big.fst.txt
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||||
|
log "Stage 8: Compile HLG"
|
||||||
|
|
||||||
|
for vocab_size in ${vocab_sizes[@]}; do
|
||||||
|
lang_dir=data/lang_bpe_${vocab_size}
|
||||||
|
|
||||||
|
if [ ! -f $lang_dir/HLG.pt ]; then
|
||||||
|
./local/compile_hlg.py \
|
||||||
|
--lang-dir $lang_dir \
|
||||||
|
--lm G_4_gram_small
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
@ -466,9 +466,7 @@ def one_best_decoding(
|
|||||||
Return:
|
Return:
|
||||||
An FsaVec containing linear paths.
|
An FsaVec containing linear paths.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if lm_scale_list is not None:
|
if lm_scale_list is not None:
|
||||||
|
|
||||||
ans = dict()
|
ans = dict()
|
||||||
saved_am_scores = lattice.scores - lattice.lm_scores
|
saved_am_scores = lattice.scores - lattice.lm_scores
|
||||||
for lm_scale in lm_scale_list:
|
for lm_scale in lm_scale_list:
|
||||||
|
@ -112,7 +112,7 @@ def uniq_lexicon_test():
|
|||||||
# But there is no word "ca" in the lexicon, so our
|
# But there is no word "ca" in the lexicon, so our
|
||||||
# implementation returns the id of "<UNK>"
|
# implementation returns the id of "<UNK>"
|
||||||
print(token_ids, expected_token_ids)
|
print(token_ids, expected_token_ids)
|
||||||
assert token_ids.tolist() == [[sp.unk_id()]]
|
assert token_ids.tolist() == [[sp.piece_to_id("▁"), sp.unk_id()]]
|
||||||
|
|
||||||
# case 3: With OOV
|
# case 3: With OOV
|
||||||
texts = ["foo"]
|
texts = ["foo"]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user