Merge branch 'k2-fsa:master' into gigaspeech_recipe

Wang, Guanbo 2022-04-06 20:57:51 -04:00 committed by GitHub
commit 3e5436e5d4
28 changed files with 1204 additions and 670 deletions

View File

@@ -14,4 +14,5 @@ per-file-ignores =
 exclude =
   .git,
   **/data/**,
-  icefall/shared/make_kn_lm.py
+  icefall/shared/make_kn_lm.py,
+  icefall/__init__.py

View File

@@ -45,7 +45,9 @@ jobs:
       - name: Install Python dependencies
        run: |
-          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2
+          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
+          # See https://github.com/psf/black/issues/2964
+          # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4

       - name: Run flake8
         shell: bash

View File

@@ -27,9 +27,21 @@ Installation
 ``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
 `lhotse <https://github.com/lhotse-speech/lhotse>`_.

-We recommend you to install ``k2`` first, as ``k2`` is bound to
-a specific version of PyTorch after compilation. Install ``k2`` also
-installs its dependency PyTorch, which can be reused by ``lhotse``.
+We recommend you to use the following steps to install the dependencies.
+
+- (0) Install PyTorch and torchaudio
+- (1) Install k2
+- (2) Install lhotse
+
+.. caution::
+
+   Installation order matters.
+
+(0) Install PyTorch and torchaudio
+----------------------------------
+
+Please refer to `<https://pytorch.org/>`_ to install PyTorch
+and torchaudio.

 (1) Install k2

@@ -54,14 +66,15 @@ to install ``k2``.
 Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
 to install ``lhotse``.

-.. HINT::
-
-   Install ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_.
-
-.. CAUTION::
-
-   If you have installed ``torchaudio``, please consider uninstalling it before
-   installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
+.. hint::
+
+   We strongly recommend you to use::
+
+     pip install git+https://github.com/lhotse-speech/lhotse
+
+   to install the latest version of lhotse.

 (3) Download icefall
 --------------------

View File

@@ -70,7 +70,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   #   |-- lexicon.txt
   #   `-- speaker.info

-  if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then
+  if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then
     lhotse download aishell $dl_dir
   fi

View File

@@ -55,18 +55,17 @@ from typing import List
 import kaldifeat
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model

-from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict


 def get_parser():

@@ -111,6 +110,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )

+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,

@@ -137,70 +143,6 @@ def get_parser():
     return parser


-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-            "sample_rate": 16000,
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:

@@ -225,6 +167,7 @@ def read_sound_files(
     return ans


+@torch.no_grad()
 def main():
     parser = get_parser()
     args = parser.parse_args()

@@ -249,7 +192,7 @@ def main():
     model = get_transducer_model(params)

     checkpoint = torch.load(args.checkpoint, map_location="cpu")
-    model.load_state_dict(checkpoint["model"])
+    model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
     model.device = device

@@ -279,12 +222,22 @@ def main():
         features, batch_first=True, padding_value=math.log(1e-10)
     )

-    hyps = []
-    with torch.no_grad():
-        encoder_out, encoder_out_lens = model.encoder(
-            x=features, x_lens=feature_lens
-        )
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )

+    hyp_list = []
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
         for i in range(encoder_out.size(0)):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]

@@ -301,17 +254,15 @@ def main():
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.method}"
                 )
-            hyps.append([lexicon.token_table[i] for i in hyp])
+            hyp_list.append(hyp)
+
+    hyps = []
+    for hyp in hyp_list:
+        hyps.append([lexicon.token_table[i] for i in hyp])

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
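
A quick aside on the `@torch.no_grad()` change above: applying the decorator disables autograd for the whole of main(), which is why the inner `with torch.no_grad():` block around the encoder call could be dropped. A minimal, self-contained illustration (toy tensors, no model involved):

import torch

@torch.no_grad()
def main():
    x = torch.ones(2, requires_grad=True)
    y = x * 2
    # no autograd graph is recorded inside the decorated function
    print(y.requires_grad)  # False

main()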


View File

@@ -17,14 +17,96 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional

+import k2
 import torch
 from model import Transducer

+from icefall.decode import one_best_decoding
+from icefall.utils import get_texts
+
+
+def fast_beam_search(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+) -> List[List[int]]:
+    """It limits the maximum number of symbols per frame to 1.
+
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding; may be a TrivialGraph or an HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in
+        `encoder_out` before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+    Returns:
+      Return the decoded result.
+    """
+    assert encoder_out.ndim == 3
+
+    context_size = model.decoder.context_size
+    vocab_size = model.decoder.vocab_size
+
+    B, T, C = encoder_out.shape
+
+    config = k2.RnntDecodingConfig(
+        vocab_size=vocab_size,
+        decoder_history_len=context_size,
+        beam=beam,
+        max_contexts=max_contexts,
+        max_states=max_states,
+    )
+    individual_streams = []
+    for i in range(B):
+        individual_streams.append(k2.RnntDecodingStream(decoding_graph))
+    decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
+
+    for t in range(T):
+        # shape is a RaggedShape of shape (B, context)
+        # contexts is a Tensor of shape (shape.NumElements(), context_size)
+        shape, contexts = decoding_streams.get_contexts()
+        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
+        contexts = contexts.to(torch.int64)
+        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
+        decoder_out = model.decoder(contexts, need_pad=False)
+        # current_encoder_out is of shape
+        # (shape.NumElements(), 1, encoder_out_dim)
+        # fmt: off
+        current_encoder_out = torch.index_select(
+            encoder_out[:, t:t + 1, :], 0, shape.row_ids(1)
+        )
+        # fmt: on
+        logits = model.joiner(
+            current_encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+        logits = logits.squeeze(1).squeeze(1)
+        log_probs = logits.log_softmax(dim=-1)
+        decoding_streams.advance(log_probs)
+    decoding_streams.terminate_and_flush_to_streams()
+    lattice = decoding_streams.format_output(encoder_out_lens.tolist())
+
+    best_path = one_best_decoding(lattice)
+    hyps = get_texts(best_path)
+
+    return hyps
+
+
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.
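
The trickiest line in fast_beam_search above is the torch.index_select over shape.row_ids(1): a decoding stream may hold several active contexts at frame t, and row_ids(1) maps each context back to its stream, so the stream's encoder frame gets replicated once per context before the joiner runs. A pure-torch sketch of that replication (toy shapes, no k2 involved):

import torch

B, C = 2, 4  # 2 streams, encoder dim 4
encoder_frame = torch.arange(B * C, dtype=torch.float32).reshape(B, 1, C)
row_ids = torch.tensor([0, 0, 1])  # stream 0 has 2 contexts; stream 1 has 1
current_encoder_out = torch.index_select(encoder_frame, 0, row_ids)
print(current_encoder_out.shape)  # torch.Size([3, 1, 4])
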
@@ -96,6 +178,68 @@ def greedy_search(
     return hyp


+def greedy_search_batch(
+    model: Transducer, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+    Returns:
+      Return a list-of-list of token IDs containing the decoded results.
+      len(ans) equals to encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = model.device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
+        # logits' shape: (batch_size, 1, 1, vocab_size)
+        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )
+            decoder_out = model.decoder(decoder_input, need_pad=False)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 @dataclass
 class Hypothesis:
     # The predicted tokens so far.
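
The `emitted` flag in greedy_search_batch above is what keeps the batch loop cheap: the decoder is re-run only on frames where at least one stream produced a non-blank symbol, with its input rebuilt from the trailing context_size tokens of every stream. A pure-Python toy run of that update rule (no model involved):

blank_id, context_size = 0, 2
hyps = [[blank_id] * context_size for _ in range(3)]  # 3 parallel streams
y = [5, 0, 7]  # per-stream argmax at one frame; 0 means blank
emitted = False
for i, v in enumerate(y):
    if v != blank_id:
        hyps[i].append(v)
        emitted = True
if emitted:
    decoder_input = [h[-context_size:] for h in hyps]
    print(decoder_input)  # [[0, 5], [0, 0], [0, 7]]
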
@@ -222,13 +366,156 @@ class HypothesisList(object):
         return ", ".join(s)


+def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+    """Return a ragged shape with axes [utt][num_hyps].
+
+    Args:
+      hyps:
+        len(hyps) == batch_size. It contains the current hypothesis for
+        each utterance in the batch.
+    Returns:
+      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
+      the shape is on CPU.
+    """
+    num_hyps = [len(h) for h in hyps]
+
+    # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
+    # to get exclusive sum later.
+    num_hyps.insert(0, 0)
+
+    num_hyps = torch.tensor(num_hyps)
+    row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
+    ans = k2.ragged.create_ragged_shape2(
+        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
+    )
+    return ans
+
+
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+) -> List[List[int]]:
+    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C).
+      beam:
+        Number of active paths during the beam search.
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding results
+      for the i-th utterance.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = model.device
+
+    B = [HypothesisList() for _ in range(batch_size)]
+    for i in range(batch_size):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            )
+        )
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+
+        hyps_shape = _get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+            topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+    ans = [h.ys[context_size:] for h in best_hyps]
+
+    return ans
+
+
+def _deprecated_modified_beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.

+    It decodes only one utterance at a time. We keep it only for reference.
+    The function :func:`modified_beam_search` should be preferred as it
+    supports batch decoding.
+
     Args:
       model:
         An instance of `Transducer`.
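
The top-k bookkeeping in modified_beam_search above flattens each utterance's (num_hyps, vocab_size) score matrix, takes topk over the flat vector, and splits every flat index back into a hypothesis index and a token index with `//` and `%`. A numeric sketch with plain torch:

import torch

vocab_size = 5
# scores for 2 hypotheses x 5 tokens, already flattened to 1-D
log_probs = torch.tensor(
    [-0.1, -2.0, -3.0, -4.0, -5.0, -0.2, -0.3, -6.0, -7.0, -8.0]
)
topk_log_probs, topk_indexes = log_probs.topk(3)
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()   # which hypothesis
topk_token_indexes = (topk_indexes % vocab_size).tolist()  # which token
print(topk_hyp_indexes, topk_token_indexes)  # [0, 1, 1] [0, 0, 1]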

View File

@@ -42,6 +42,17 @@ Usage:
     --max-duration 100 \
     --decoding-method modified_beam_search \
     --beam-size 4

+(4) fast beam search
+./pruned_transducer_stateless/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --exp-dir ./pruned_transducer_stateless/exp \
+    --max-duration 1500 \
+    --decoding-method fast_beam_search \
+    --beam 4 \
+    --max-contexts 4 \
+    --max-states 8
+
 """

@@ -49,13 +60,20 @@ import argparse
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

+import k2
 import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    fast_beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from train import get_params, get_transducer_model

 from icefall.checkpoint import (

@@ -80,27 +98,28 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )

+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
-    )
-
-    parser.add_argument(
-        "--avg-last-n",
-        type=int,
-        default=0,
-        help="""If positive, --epoch and --avg are ignored and it
-        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
-        where xxx is the number of processed batches while
-        saving that checkpoint.
-        """,
+        "'--epoch' and '--iter'",
     )

     parser.add_argument(

@@ -125,6 +144,7 @@ def get_parser():
           - greedy_search
           - beam_search
           - modified_beam_search
+          - fast_beam_search
         """,
     )

@@ -132,8 +152,35 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""Used only when --decoding-method is
-        beam_search or modified_beam_search""",
+        help="""An integer indicating how many candidates we will keep for each
+        frame. Used only when --decoding-method is beam_search or
+        modified_beam_search.""",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
     )

     parser.add_argument(

@@ -146,7 +193,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
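
Note that `--beam` for fast_beam_search is a score cutoff rather than a path count, matching the `beam` in Kaldi: anything scoring more than `beam` below the current best is pruned. A toy illustration of the cutoff rule on log-scores:

import torch

scores = torch.tensor([-1.2, -3.0, -9.5, -2.1])
beam = 4.0
cutoff = scores.max() - beam
print(scores >= cutoff)  # tensor([ True,  True, False,  True])
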
@@ -159,6 +206,7 @@ def decode_one_batch(
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
     batch: dict,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:

@@ -181,6 +229,9 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
     Returns:
       Return the decoding result. See above description for the format of
       the returned dict.

@@ -199,36 +250,74 @@ def decode_one_batch(
         x=feature, x_lens=feature_lens
     )
     hyps = []
-    batch_size = encoder_out.size(0)

-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        batch_size = encoder_out.size(0)
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append(sp.decode(hyp).split())

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
+    elif params.decoding_method == "fast_beam_search":
+        return {
+            (
+                f"beam_{params.beam}_"
+                f"max_contexts_{params.max_contexts}_"
+                f"max_states_{params.max_states}"
+            ): hyps
+        }
     else:
-        return {f"beam_{params.beam_size}": hyps}
+        return {f"beam_size_{params.beam_size}": hyps}

@@ -236,6 +325,7 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     sp: spm.SentencePieceProcessor,
+    decoding_graph: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.

@@ -248,6 +338,9 @@ def decode_dataset(
         The neural model.
       sp:
         The BPE model.
+      decoding_graph:
+        The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
+        only when --decoding_method is fast_beam_search.
     Returns:
       Return a dict, whose key may be "greedy_search" if greedy search
       is used, or it may be "beam_7" if beam size of 7 is used.

@@ -275,6 +368,7 @@ def decode_dataset(
             params=params,
             model=model,
             sp=sp,
+            decoding_graph=decoding_graph,
             batch=batch,
         )

@@ -355,13 +449,24 @@ def main():
     assert params.decoding_method in (
         "greedy_search",
         "beam_search",
+        "fast_beam_search",
         "modified_beam_search",
     )
     params.res_dir = params.exp_dir / params.decoding_method

-    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if "fast_beam_search" in params.decoding_method:
+        params.suffix += f"-beam-{params.beam}"
+        params.suffix += f"-max-contexts-{params.max_contexts}"
+        params.suffix += f"-max-states-{params.max_states}"
+    elif "beam_search" in params.decoding_method:
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -387,8 +492,20 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    if params.avg_last_n > 0:
-        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))

@@ -408,6 +525,11 @@ def main():
     model.eval()
     model.device = device

+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -428,6 +550,7 @@ def main():
         params=params,
         model=model,
         sp=sp,
+        decoding_graph=decoding_graph,
     )

     save_results(
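
Tying this file's pieces together: the fast_beam_search path is driven by a trivial decoding graph built from the model's vocabulary, as in main() above. A condensed usage sketch; `model`, `params`, `sp`, `features`, `feature_lens`, and `device` are assumed to be set up exactly as in decode.py:

import k2

decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
hyp_tokens = fast_beam_search(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=params.beam,  # float score cutoff, not a path count
    max_contexts=params.max_contexts,
    max_states=params.max_states,
)
hyps = [hyp.split() for hyp in sp.decode(hyp_tokens)]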

View File

@@ -61,6 +61,7 @@ class Decoder(nn.Module):
         assert context_size >= 1, context_size
         self.context_size = context_size
+        self.vocab_size = vocab_size
         if context_size > 1:
             self.conv = nn.Conv1d(
                 in_channels=embedding_dim,

View File

@@ -50,7 +50,12 @@ import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
 from train import get_params, get_transducer_model

@@ -122,7 +127,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame. Used only when
         --method is greedy_search.
         """,

@@ -224,28 +229,43 @@ def main():
     if params.method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

-    for i in range(num_waves):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(f"Unsupported method: {params.method}")
-
-        hyps.append(sp.decode(hyp).split())
+    if params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):

View File

@@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple

@@ -392,12 +393,16 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
-        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]

-    params["start_epoch"] = saved_params["cur_epoch"]
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]

     return saved_params

@@ -492,7 +497,11 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

@@ -783,6 +792,13 @@ def run(rank, world_size, args):
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
         return 1.0 <= c.duration <= 20.0

     num_in_total = len(train_cuts)

@@ -797,7 +813,9 @@ def run(rank, world_size, args):
     logging.info(f"After removing short and long utterances: {num_left}")
     logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")

-    if checkpoints and "sampler" in checkpoints:
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
         sampler_state_dict = checkpoints["sampler"]
     else:
         sampler_state_dict = None
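
The warnings.catch_warnings block in compute_loss exists because, on some PyTorch versions (roughly 1.8 and later), `//` on tensors emits a UserWarning that __floordiv__ is deprecated. A minimal reproduction, with the warning-free equivalent shown only for illustration (the recipe keeps `//` and silences the warning instead):

import torch

feature_lens = torch.tensor([800, 640])
frames = feature_lens // 4  # may emit the deprecation UserWarning
frames_alt = torch.div(feature_lens, 4, rounding_mode="floor")  # silent
assert torch.equal(frames, frames_alt)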

View File

@@ -23,6 +23,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional

+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,

@@ -34,11 +35,20 @@ from lhotse.dataset import (
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader

 from icefall.utils import str2bool


+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
 class LibriSpeechAsrDataModule:
     """
     DataModule for k2 ASR experiments.

@@ -301,12 +311,18 @@ class LibriSpeechAsrDataModule:
             logging.info("Loading sampler state dict")
             train_sampler.load_state_dict(sampler_state_dict)

+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
+            worker_init_fn=worker_init_fn,
         )

         return train_dl
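
The _SeedWorkers addition makes dataloading workers deterministic but distinct: one base seed is drawn in the main process and worker i is seeded with seed + i. A self-contained sketch of the pattern; fix_random_seed below is a stand-in for lhotse.utils.fix_random_seed so the snippet runs without lhotse:

import random

import torch
from torch.utils.data import DataLoader

def fix_random_seed(seed: int) -> None:
    # stand-in for lhotse.utils.fix_random_seed
    random.seed(seed)
    torch.manual_seed(seed)

class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        # each worker gets a distinct, reproducible seed
        fix_random_seed(self.seed + worker_id)

seed = torch.randint(0, 100000, ()).item()
dl = DataLoader(list(range(8)), num_workers=2, worker_init_fn=_SeedWorkers(seed))
print([batch for batch in dl])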

View File

@@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple

@@ -393,7 +394,11 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2"
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple

@@ -397,7 +398,11 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -17,6 +17,7 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional

+import k2
 import torch
 from model import Transducer

@@ -24,7 +25,7 @@ from model import Transducer
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.

@@ -80,7 +81,7 @@ def greedy_search(
         logits = model.joiner(
             current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
         )
-        # logits is (1, 1, 1, vocab_size)
+        # logits is (1, vocab_size)

         y = logits.argmax().item()
         if y != blank_id:

@@ -101,6 +102,75 @@ def greedy_search(
     return hyp


+def greedy_search_batch(
+    model: Transducer, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+    Returns:
+      Return a list-of-list of token IDs containing the decoded results.
+      len(ans) equals to encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = model.device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+
+    encoder_out_len = torch.ones(batch_size, dtype=torch.int32)
+    decoder_out_len = torch.ones(batch_size, dtype=torch.int32)
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
+        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
+        logits = model.joiner(
+            current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
+        )  # (batch_size, vocab_size)
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )  # (batch_size, context_size)
+            decoder_out = model.decoder(
+                decoder_input,
+                need_pad=False,
+            )  # (batch_size, 1, decoder_out_dim)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 @dataclass
 class Hypothesis:
     # The predicted tokens so far.

@@ -252,9 +322,11 @@ def run_decoder(
     device = model.device

-    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [ys[-context_size:]],
+        device=device,
+        dtype=torch.int64,
+    ).reshape(1, context_size)

     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_cache[key] = decoder_out

@@ -314,13 +386,158 @@ def run_joiner(
     return log_prob


+def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+    """Return a ragged shape with axes [utt][num_hyps].
+
+    Args:
+      hyps:
+        len(hyps) == batch_size. It contains the current hypothesis for
+        each utterance in the batch.
+    Returns:
+      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
+      the shape is on CPU.
+    """
+    num_hyps = [len(h) for h in hyps]
+
+    # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
+    # to get exclusive sum later.
+    num_hyps.insert(0, 0)
+
+    num_hyps = torch.tensor(num_hyps)
+    row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
+    ans = k2.ragged.create_ragged_shape2(
+        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
+    )
+    return ans
+
+
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+) -> List[List[int]]:
+    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C).
+      beam:
+        Number of active paths during the beam search.
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding results
+      for the i-th utterance.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = model.device
+
+    B = [HypothesisList() for _ in range(batch_size)]
+    for i in range(batch_size):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            )
+        )
+
+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
+        # current_encoder_out's shape is: (batch_size, 1, encoder_out_dim)
+
+        hyps_shape = _get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_out is of shape (num_hyps, 1, decoder_output_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        # logits is of shape (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+            topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+    ans = [h.ys[context_size:] for h in best_hyps]
+
+    return ans
+
+
+def _deprecated_modified_beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.

+    It decodes only one utterance at a time. We keep it only for reference.
+    The function :func:`modified_beam_search` should be preferred as it
+    supports batch decoding.
+
     Args:
       model:
         An instance of `Transducer`.

@@ -341,12 +558,6 @@ def modified_beam_search(
     device = model.device

-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
-
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-
     T = encoder_out.size(1)

     B = HypothesisList()

View File

@@ -109,8 +109,11 @@ class Conformer(Transformer):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

-        # Caution: We assume the subsampling factor is 4!
-        lengths = ((x_lens - 1) // 2 - 1) // 2
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            # Caution: We assume the subsampling factor is 4!
+            lengths = ((x_lens - 1) // 2 - 1) // 2
+
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
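
The guarded formula is the usual convolutional-subsampling arithmetic: two stride-2 reductions shrink T to roughly T/4, and the catch_warnings wrapper silences the same `//` deprecation warning as in compute_loss. A worked example of the length computation:

import torch

x_lens = torch.tensor([100, 83])
lengths = ((x_lens - 1) // 2 - 1) // 2
print(lengths)  # tensor([24, 20])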

View File

@@ -55,14 +55,15 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model

 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,

@@ -135,7 +136,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )

@@ -143,70 +144,6 @@ def get_parser():
     return parser


-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,

@@ -251,32 +188,47 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
-    batch_size = encoder_out.size(0)
+    hyp_list: List[List[int]] = []

-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    if (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
+        batch_size = encoder_out.size(0)
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}

@@ -487,8 +439,5 @@ def main():
     logging.info("Done!")


-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()
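
The batched code paths above lean on a SentencePiece behavior worth spelling out: `sp.decode` accepts a list of token-ID lists and returns one string per list, so a whole batch is detokenized in a single call. A small sketch (the model path is hypothetical):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("bpe.model")  # hypothetical path to a trained BPE model
hyp_tokens = [[23, 7, 41], [8, 5]]
for hyp in sp.decode(hyp_tokens):
    print(hyp.split())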

View File

@@ -59,17 +59,15 @@ from typing import List

 import kaldifeat
 import sentencepiece as spm
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model

-from icefall.env import get_env_info
-from icefall.utils import AttributeDict

 def get_parser():
@@ -115,6 +113,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )

+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,
@@ -132,7 +137,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame. Used only when
         --method is greedy_search.
         """,
@@ -141,70 +146,6 @@ def get_parser():
     return parser

-
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "sample_rate": 16000,
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
@@ -294,33 +235,45 @@ def main():
     )

     num_waves = encoder_out.size(0)
-    hyps = []
+    hyp_list = []
     msg = f"Using {params.method}"
     if params.method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)
-    for i in range(num_waves):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(f"Unsupported method: {params.method}")
-        hyps.append(sp.decode(hyp).split())
+
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):

View File

@@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple
@@ -419,7 +420,11 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -22,6 +22,7 @@ import logging
 from pathlib import Path
 from typing import Optional

+import torch
 from lhotse import CutSet, Fbank, FbankConfig
 from lhotse.dataset import (
     BucketingSampler,
@@ -34,11 +35,20 @@ from lhotse.dataset.input_strategies import (
     OnTheFlyFeatures,
     PrecomputedFeatures,
 )
+from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader

 from icefall.utils import str2bool


+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
 class AsrDataModule:
     def __init__(self, args: argparse.Namespace):
         self.args = args
@@ -253,12 +263,19 @@ class AsrDataModule:
         )

         logging.info("About to create train dataloader")
+
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
+            worker_init_fn=worker_init_fn,
         )

         return train_dl
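
The `_SeedWorkers` hook above addresses a subtle pitfall: DataLoader worker processes start from a copy of the parent's state, so without an explicit `worker_init_fn` every worker can begin with the same RNG state and apply identical "random" augmentations. Below is a minimal, self-contained sketch of the same pattern; `ToyDataset` and the local `fix_random_seed` helper are illustrative stand-ins (in the code above, `fix_random_seed` comes from `lhotse.utils`):

    import random

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset


    def fix_random_seed(seed: int) -> None:
        # Illustrative stand-in for lhotse.utils.fix_random_seed:
        # seed the commonly used RNGs.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)


    class _SeedWorkers:
        # A picklable callable; a lambda would break with spawn-based workers.
        def __init__(self, seed: int):
            self.seed = seed

        def __call__(self, worker_id: int):
            # Each worker gets a distinct but reproducible seed.
            fix_random_seed(self.seed + worker_id)


    class ToyDataset(Dataset):
        # Hypothetical dataset whose items depend on RNG state,
        # standing in for random data augmentation.
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return torch.rand(1)


    if __name__ == "__main__":
        seed = torch.randint(0, 100000, ()).item()
        dl = DataLoader(
            ToyDataset(),
            batch_size=2,
            num_workers=2,
            worker_init_fn=_SeedWorkers(seed),
        )
        for batch in dl:
            print(batch)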

View File

@@ -46,15 +46,16 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import AsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from librispeech import LibriSpeech
-from model import Transducer
+from train import get_params, get_transducer_model

 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -127,7 +128,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
@@ -135,71 +136,6 @@ def get_parser():
     return parser

-
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -244,32 +180,47 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
+    hyp_list = []
     batch_size = encoder_out.size(0)

-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    if (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -483,8 +434,5 @@ def main():
     logging.info("Done!")

-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()

View File

@@ -59,17 +59,15 @@ from typing import List

 import kaldifeat
 import sentencepiece as spm
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model

-from icefall.env import get_env_info
-from icefall.utils import AttributeDict

 def get_parser():
@@ -115,6 +113,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )

+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,
@@ -132,7 +137,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame. Used only when
         --method is greedy_search.
         """,
@@ -141,70 +146,6 @@ def get_parser():
     return parser

-
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "sample_rate": 16000,
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
@@ -294,33 +235,46 @@ def main():
     )

     num_waves = encoder_out.size(0)
-    hyps = []
+    hyp_list = []
     msg = f"Using {params.method}"
     if params.method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)
-    for i in range(num_waves):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(f"Unsupported method: {params.method}")
-        hyps.append(sp.decode(hyp).split())
+
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):

View File

@@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"

 import argparse
 import logging
 import random
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple
@@ -466,7 +467,11 @@ def compute_loss(
     assert loss.requires_grad == is_training

     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -0,0 +1,55 @@
+from .checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+    remove_checkpoints,
+    save_checkpoint,
+    save_checkpoint_with_global_batch_idx,
+)
+
+from .decode import (
+    get_lattice,
+    nbest_decoding,
+    nbest_oracle,
+    one_best_decoding,
+    rescore_with_attention_decoder,
+    rescore_with_n_best_list,
+    rescore_with_whole_lattice,
+)
+
+from .dist import (
+    cleanup_dist,
+    setup_dist,
+)
+
+from .env import (
+    get_env_info,
+    get_git_branch_name,
+    get_git_date,
+    get_git_sha1,
+)
+
+from .utils import (
+    AttributeDict,
+    MetricsTracker,
+    add_eos,
+    add_sos,
+    concat,
+    encode_supervisions,
+    get_alignments,
+    get_executor,
+    get_texts,
+    l1_norm,
+    l2_norm,
+    linf_norm,
+    load_alignments,
+    make_pad_mask,
+    measure_gradient_norms,
+    measure_weight_norms,
+    optim_step_and_measure_param_change,
+    save_alignments,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)

View File

@@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx(
     )


-def find_checkpoints(out_dir: Path) -> List[str]:
+def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
     """Find all available checkpoints in a directory.

     The checkpoint filenames have the form: `checkpoint-xxx.pt`
     where xxx is a numerical value.

+    Assume you have the following checkpoints in the folder `foo`:
+
+        - checkpoint-1.pt
+        - checkpoint-20.pt
+        - checkpoint-300.pt
+        - checkpoint-4000.pt
+
+    Case 1 (Return all checkpoints)::
+
+        find_checkpoints(out_dir='foo')
+
+    Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e.,
+    checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)::
+
+        find_checkpoints(out_dir='foo', iteration=20)
+
+    Case 3 (Return checkpoints older than checkpoint-20.pt, i.e.,
+    checkpoint-20.pt, checkpoint-1.pt)::
+
+        find_checkpoints(out_dir='foo', iteration=-20)
+
     Args:
       out_dir:
         The directory where to search for checkpoints.
+      iteration:
+        If it is 0, return all available checkpoints.
+        If it is positive, return the checkpoints whose iteration number is
+        greater than or equal to `iteration`.
+        If it is negative, return the checkpoints whose iteration number is
+        less than or equal to `-iteration`.
     Returns:
       Return a list of checkpoint filenames, sorted in descending
       order by the numerical value in the filename.
     """
     checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
     pattern = re.compile(r"checkpoint-([0-9]+).pt")
-    idx_checkpoints = [
+    iter_checkpoints = [
         (int(pattern.search(c).group(1)), c) for c in checkpoints
     ]
+    # iter_checkpoints is a list of tuples. Each tuple contains
+    # two elements: (iteration_number, checkpoint-iteration_number.pt)
+    iter_checkpoints = sorted(
+        iter_checkpoints, reverse=True, key=lambda x: x[0]
+    )
+    if iteration >= 0:
+        ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
+    else:
+        ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]
-    idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
-    ans = [ic[1] for ic in idx_checkpoints]

     return ans
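
For reference, a short sketch of how the new `iteration` argument might be used when selecting checkpoints, e.g., before passing them to `average_checkpoints`; the experiment directory and iteration numbers here are hypothetical:

    from pathlib import Path

    from icefall.checkpoint import find_checkpoints

    out_dir = Path("transducer_stateless/exp")  # hypothetical directory

    # All checkpoints, sorted by iteration number, newest first.
    all_ckpts = find_checkpoints(out_dir)

    # Only checkpoints whose iteration number is >= 4000, e.g., the most
    # recent ones saved by save_checkpoint_with_global_batch_idx().
    recent = find_checkpoints(out_dir, iteration=4000)

    # Only checkpoints whose iteration number is <= 4000.
    older = find_checkpoints(out_dir, iteration=-4000)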

View File

@@ -135,8 +135,13 @@ def get_diagnostics_for_dim(
             return ""
         count = sum(counts)
         stats = stats / count
-        stats, _ = torch.symeig(stats)
-        stats = stats.abs().sqrt()
+        try:
+            eigs, _ = torch.symeig(stats)
+            stats = eigs.abs().sqrt()
+        except:  # noqa
+            print("Error getting eigenvalues, trying another method.")
+            eigs = torch.linalg.eigvals(stats)
+            stats = eigs.abs().sqrt()
         # sqrt so it reflects data magnitude, like stddev- not variance
     elif sizes_same:
         stats = torch.stack(stats).sum(dim=0)
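
A note on the fallback above: `torch.symeig` is deprecated in recent PyTorch releases, and `torch.linalg.eigvals` returns complex eigenvalues, hence the `.abs()`. On PyTorch versions that provide `torch.linalg.eigvalsh`, an equivalent computation for a symmetric input would look like the following sketch (not part of this commit):

    import torch

    stats = torch.randn(4, 4)
    stats = stats @ stats.t()  # symmetric, like an accumulated covariance

    # eigvalsh assumes a symmetric (Hermitian) input and returns real
    # eigenvalues, so no try/except over complex results is needed.
    eigs = torch.linalg.eigvalsh(stats)
    stats = eigs.abs().sqrt()  # magnitude on a stddev-like scale
    print(stats)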

View File

@ -1,5 +1,6 @@
[tool.isort] [tool.isort]
profile = "black" profile = "black"
skip = ["icefall/__init__.py"]
[tool.black] [tool.black]
line-length = 80 line-length = 80

View File

@@ -11,7 +11,7 @@ graphviz==0.19.1

 -f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu
 -f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu

--f https://k2-fsa.org/nightly/ k2==1.9.dev20211101+cpu.torch1.10.0
+-f https://k2-fsa.org/nightly/ k2==1.14.dev20220316+cpu.torch1.10.0

 git+https://github.com/lhotse-speech/lhotse
 kaldilm==1.11