Merge branch 'k2-fsa:master' into tedlium3-pruned-transducer-stateless-new

Mingshuang Luo 2022-04-11 10:07:06 +08:00 committed by GitHub
commit 6f8a9e97ec
33 changed files with 1083 additions and 873 deletions

View File

@@ -13,4 +13,5 @@ per-file-ignores =
 exclude =
   .git,
   **/data/**,
-  icefall/shared/make_kn_lm.py
+  icefall/shared/make_kn_lm.py,
+  icefall/__init__.py

View File

@@ -45,7 +45,9 @@ jobs:
       - name: Install Python dependencies
         run: |
-          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2
+          python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4
+          # See https://github.com/psf/black/issues/2964
+          # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4
       - name: Run flake8
         shell: bash

View File

@@ -4,6 +4,8 @@ repos:
     hooks:
       - id: black
         args: [--line-length=80]
+        additional_dependencies: ['click==8.0.1']
+        exclude: icefall\/__init__\.py
   - repo: https://github.com/PyCQA/flake8
     rev: 3.9.2

View File

@@ -27,9 +27,21 @@ Installation
 ``icefall`` depends on `k2 <https://github.com/k2-fsa/k2>`_ and
 `lhotse <https://github.com/lhotse-speech/lhotse>`_.
 
-We recommend you to install ``k2`` first, as ``k2`` is bound to
-a specific version of PyTorch after compilation. Install ``k2`` also
-installs its dependency PyTorch, which can be reused by ``lhotse``.
+We recommend you to use the following steps to install the dependencies.
+
+- (0) Install PyTorch and torchaudio
+- (1) Install k2
+- (2) Install lhotse
+
+.. caution::
+
+  Installation order matters.
+
+(0) Install PyTorch and torchaudio
+----------------------------------
+
+Please refer `<https://pytorch.org/>`_ to install PyTorch
+and torchaudio.
 
 (1) Install k2

@@ -54,14 +66,15 @@ to install ``k2``.
 Please refer to `<https://lhotse.readthedocs.io/en/latest/getting-started.html#installation>`_
 to install ``lhotse``.
 
-.. HINT::
-
-  Install ``lhotse`` also installs its dependency `torchaudio <https://github.com/pytorch/audio>`_.
-
-.. CAUTION::
-
-  If you have installed ``torchaudio``, please consider uninstalling it before
-  installing ``lhotse``. Otherwise, it may update your already installed PyTorch.
+.. hint::
+
+  We strongly recommend you to use::
+
+    pip install git+https://github.com/lhotse-speech/lhotse
+
+  to install the latest version of lhotse.
 
 (3) Download icefall
 --------------------

View File

@@ -1,98 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
"""
Implement the LabelSmoothingLoss proposed in the following paper
https://arxiv.org/pdf/1512.00567.pdf
(Rethinking the Inception Architecture for Computer Vision)
"""
def __init__(
self,
ignore_index: int = -1,
label_smoothing: float = 0.1,
reduction: str = "sum",
) -> None:
"""
Args:
ignore_index:
ignored class id
label_smoothing:
smoothing rate (0.0 means the conventional cross entropy loss)
reduction:
It has the same meaning as the reduction in
`torch.nn.CrossEntropyLoss`. It can be one of the following three
values: (1) "none": No reduction will be applied. (2) "mean": the
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.ignore_index of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.ndim == 3
assert target.ndim == 2
assert x.shape[:2] == target.shape
num_classes = x.size(-1)
x = x.reshape(-1, num_classes)
# Now x is of shape (N*T, C)
# We don't want to change target in-place below,
# so we make a copy of it here
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
target[ignored] = 0
true_dist = torch.nn.functional.one_hot(
target, num_classes=num_classes
).to(x)
true_dist = (
true_dist * (1 - self.label_smoothing)
+ self.label_smoothing / num_classes
)
# Set the value of ignored indexes to 0
true_dist[ignored] = 0
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
return loss.sum()
elif self.reduction == "mean":
return loss.sum() / (~ignored).sum()
else:
return loss.sum(dim=-1)

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/conformer_ctc/label_smoothing.py

View File

@@ -1,98 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
"""
Implement the LabelSmoothingLoss proposed in the following paper
https://arxiv.org/pdf/1512.00567.pdf
(Rethinking the Inception Architecture for Computer Vision)
"""
def __init__(
self,
ignore_index: int = -1,
label_smoothing: float = 0.1,
reduction: str = "sum",
) -> None:
"""
Args:
ignore_index:
ignored class id
label_smoothing:
smoothing rate (0.0 means the conventional cross entropy loss)
reduction:
It has the same meaning as the reduction in
`torch.nn.CrossEntropyLoss`. It can be one of the following three
values: (1) "none": No reduction will be applied. (2) "mean": the
mean of the output is taken. (3) "sum": the output will be summed.
"""
super().__init__()
assert 0.0 <= label_smoothing < 1.0
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.reduction = reduction
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute loss between x and target.
Args:
x:
prediction of dimension
(batch_size, input_length, number_of_classes).
target:
target masked with self.ignore_index of
dimension (batch_size, input_length).
Returns:
A scalar tensor containing the loss without normalization.
"""
assert x.ndim == 3
assert target.ndim == 2
assert x.shape[:2] == target.shape
num_classes = x.size(-1)
x = x.reshape(-1, num_classes)
# Now x is of shape (N*T, C)
# We don't want to change target in-place below,
# so we make a copy of it here
target = target.clone().reshape(-1)
ignored = target == self.ignore_index
target[ignored] = 0
true_dist = torch.nn.functional.one_hot(
target, num_classes=num_classes
).to(x)
true_dist = (
true_dist * (1 - self.label_smoothing)
+ self.label_smoothing / num_classes
)
# Set the value of ignored indexes to 0
true_dist[ignored] = 0
loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
if self.reduction == "sum":
return loss.sum()
elif self.reduction == "mean":
return loss.sum() / (~ignored).sum()
else:
return loss.sum(dim=-1)

View File

@@ -0,0 +1 @@
../conformer_ctc/label_smoothing.py

View File

@@ -70,7 +70,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   #   |-- lexicon.txt
   #   `-- speaker.info
-  if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then
+  if [ ! -d $dl_dir/aishell/data_aishell/wav/train ]; then
     lhotse download aishell $dl_dir
   fi

View File

@@ -55,18 +55,17 @@ from typing import List
 import kaldifeat
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
 
-from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict
 
 
 def get_parser():
@@ -111,6 +110,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )
 
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,

@@ -137,70 +143,6 @@ def get_parser():
     return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
"sample_rate": 16000,
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:

@@ -225,6 +167,7 @@ def read_sound_files(
     return ans
 
+@torch.no_grad()
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -249,7 +192,7 @@ def main():
     model = get_transducer_model(params)
     checkpoint = torch.load(args.checkpoint, map_location="cpu")
-    model.load_state_dict(checkpoint["model"])
+    model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
     model.device = device
@@ -279,12 +222,22 @@ def main():
         features, batch_first=True, padding_value=math.log(1e-10)
     )
 
-    hyps = []
-    with torch.no_grad():
-        encoder_out, encoder_out_lens = model.encoder(
-            x=features, x_lens=feature_lens
-        )
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
+
+    hyp_list = []
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
         for i in range(encoder_out.size(0)):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]

@@ -301,17 +254,15 @@ def main():
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.method}"
                 )
-            hyps.append([lexicon.token_table[i] for i in hyp])
+            hyp_list.append(hyp)
+
+    hyps = []
+    for hyp in hyp_list:
+        hyps.append([lexicon.token_table[i] for i in hyp])
 
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):

View File

@@ -55,18 +55,17 @@ from typing import List
 import kaldifeat
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
 
-from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict
 
 
 def get_parser():
@@ -111,6 +110,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )
 
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,

@@ -137,70 +143,6 @@ def get_parser():
     return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
"sample_rate": 16000,
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:

@@ -225,6 +167,7 @@ def read_sound_files(
     return ans
 
+@torch.no_grad()
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -279,12 +222,22 @@ def main():
         features, batch_first=True, padding_value=math.log(1e-10)
     )
 
-    hyps = []
-    with torch.no_grad():
-        encoder_out, encoder_out_lens = model.encoder(
-            x=features, x_lens=feature_lens
-        )
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
+
+    hyp_list = []
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
         for i in range(encoder_out.size(0)):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]

@@ -301,17 +254,15 @@ def main():
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.method}"
                 )
-            hyps.append([lexicon.token_table[i] for i in hyp])
+            hyp_list.append(hyp)
+
+    hyps = []
+    for hyp in hyp_list:
+        hyps.append([lexicon.token_table[i] for i in hyp])
 
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):

View File

@@ -76,7 +76,11 @@ class LabelSmoothingLoss(torch.nn.Module):
         target = target.clone().reshape(-1)
         ignored = target == self.ignore_index
-        target[ignored] = 0
+
+        # See https://github.com/k2-fsa/icefall/issues/240
+        # and https://github.com/k2-fsa/icefall/issues/297
+        # for why we don't use target[ignored] = 0 here
+        target = torch.where(ignored, torch.zeros_like(target), target)
 
         true_dist = torch.nn.functional.one_hot(
             target, num_classes=num_classes

@@ -86,8 +90,17 @@ class LabelSmoothingLoss(torch.nn.Module):
             true_dist * (1 - self.label_smoothing)
             + self.label_smoothing / num_classes
         )
+
         # Set the value of ignored indexes to 0
-        true_dist[ignored] = 0
+        #
+        # See https://github.com/k2-fsa/icefall/issues/240
+        # and https://github.com/k2-fsa/icefall/issues/297
+        # for why we don't use true_dist[ignored] = 0 here
+        true_dist = torch.where(
+            ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
+            torch.zeros_like(true_dist),
+            true_dist,
+        )
 
         loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
 
         if self.reduction == "sum":
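The change above replaces in-place masked assignment with torch.where in both places. A minimal sketch (toy values, not from this commit) showing that the two maskings agree numerically while torch.where builds a fresh tensor instead of writing into an existing one:

import torch

# -1 plays the role of ignore_index here; the values are made up.
target = torch.tensor([3, -1, 5, -1])
ignored = target == -1

# Old style: in-place indexed assignment mutates the tensor.
in_place = target.clone()
in_place[ignored] = 0

# New style: torch.where returns a new tensor and leaves `target` untouched.
out_of_place = torch.where(ignored, torch.zeros_like(target), target)

assert torch.equal(in_place, out_of_place)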

View File

@@ -106,7 +106,7 @@ def fast_beam_search(
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.

@@ -178,6 +178,68 @@ def greedy_search(
     return hyp
def greedy_search_batch(
model: Transducer, encoder_out: torch.Tensor
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
Returns:
Return a list-of-list of token IDs containing the decoded results.
len(ans) equals to encoder_out.size(0).
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (batch_size, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_out: (batch_size, 1, decoder_out_dim)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
# logits'shape (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False)
ans = [h[context_size:] for h in hyps]
return ans
@dataclass
class Hypothesis:
    # The predicted tokens so far.

@@ -304,13 +366,156 @@ class HypothesisList(object):
        return ", ".join(s)
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
hyps:
len(hyps) == batch_size. It contains the current hypothesis for
each utterance in the batch.
Returns:
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
the shape is on CPU.
"""
num_hyps = [len(h) for h in hyps]
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
# to get exclusive sum later.
num_hyps.insert(0, 0)
num_hyps = torch.tensor(num_hyps)
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
ans = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
)
return ans
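The helper above turns per-utterance hypothesis counts into a ragged shape via an exclusive sum. A small sketch (the hypothesis counts are made up) of what the row_splits computation produces:

import torch

num_hyps = [2, 3]  # utterance 0 has 2 active hypotheses, utterance 1 has 3

# Prepend 0, then take an inclusive cumsum to get the exclusive sum.
row_splits = torch.cumsum(torch.tensor([0] + num_hyps), dim=0, dtype=torch.int32)
print(row_splits)  # tensor([0, 2, 5], dtype=torch.int32)

# The resulting shape's row_ids(1) would be [0, 0, 1, 1, 1]; that is the mapping
# used later to index current_encoder_out once per flattened hypothesis.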
def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
beam:
Number of active paths during the beam search.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
hyps_shape = _get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
# decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
return ans
def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
    """It limits the maximum number of symbols per frame to 1.

    It decodes only one utterance at a time. We keep it only for reference.
    The function :func:`modified_beam_search` should be preferred as it
    supports batch decoding.

    Args:
      model:
        An instance of `Transducer`.

View File

@@ -71,6 +71,7 @@ from beam_search import (
     beam_search,
     fast_beam_search,
     greedy_search,
+    greedy_search_batch,
     modified_beam_search,
 )
 from train import get_params, get_transducer_model
@@ -97,27 +98,28 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )
 
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
-    )
-
-    parser.add_argument(
-        "--avg-last-n",
-        type=int,
-        default=0,
-        help="""If positive, --epoch and --avg are ignored and it
-        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
-        where xxx is the number of processed batches while
-        saving that checkpoint.
-        """,
-    )
+        "'--epoch' and '--iter'",
+    )
 
     parser.add_argument(
@@ -191,7 +193,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
@@ -261,6 +263,24 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
@@ -280,12 +300,6 @@ def decode_one_batch(
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.decoding_method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
@@ -440,13 +454,19 @@ def main():
     )
 
     params.res_dir = params.exp_dir / params.decoding_method
 
-    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -472,8 +492,20 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg_last_n > 0:
-        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))

View File

@@ -50,7 +50,12 @@ import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
 from train import get_params, get_transducer_model
@@ -122,7 +127,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame. Used only when
         --method is greedy_search.
         """,
@@ -224,28 +229,43 @@ def main():
     if params.method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)
-    for i in range(num_waves):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(f"Unsupported method: {params.method}")
-
-        hyps.append(sp.decode(hyp).split())
+
+    if params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+
+            hyps.append(sp.decode(hyp).split())
 
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):

View File

@@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple
@@ -392,12 +393,16 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]
 
-    params["start_epoch"] = saved_params["cur_epoch"]
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
 
     return saved_params
@@ -492,7 +497,11 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -600,21 +609,6 @@ def train_one_epoch(
                 global_step=params.batch_idx_train,
             )
 
-    def maybe_log_param_relative_changes():
-        if (
-            params.log_diagnostics
-            and tb_writer is not None
-            and params.batch_idx_train % (params.log_interval * 5) == 0
-        ):
-            deltas = optim_step_and_measure_param_change(model, optimizer)
-            tb_writer.add_scalars(
-                "train/relative_param_change_per_minibatch",
-                deltas,
-                global_step=params.batch_idx_train,
-            )
-        else:
-            optimizer.step()
-
     cur_batch_idx = params.get("cur_batch_idx", 0)
 
     for batch_idx, batch in enumerate(train_dl):
@@ -642,7 +636,26 @@ def train_one_epoch(
         maybe_log_weights("train/param_norms")
         maybe_log_gradients("train/grad_norms")
-        maybe_log_param_relative_changes()
+
+        old_parameters = None
+        if (
+            params.log_diagnostics
+            and tb_writer is not None
+            and params.batch_idx_train % (params.log_interval * 5) == 0
+        ):
+            old_parameters = {
+                n: p.detach().clone() for n, p in model.named_parameters()
+            }
+
+        optimizer.step()
+
+        if old_parameters is not None:
+            deltas = optim_step_and_measure_param_change(model, old_parameters)
+            tb_writer.add_scalars(
+                "train/relative_param_change_per_minibatch",
+                deltas,
+                global_step=params.batch_idx_train,
+            )
 
         optimizer.zero_grad()
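With this restructuring the parameters are snapshotted before optimizer.step() and compared afterwards. A hedged sketch of the kind of per-parameter relative change such a diagnostic could report; the helper below is hypothetical and only illustrates the idea, the actual computation lives in icefall's optim_step_and_measure_param_change:

import torch

def relative_param_change(model: torch.nn.Module, old_parameters: dict) -> dict:
    # Relative L2 change of each parameter after an optimizer step.
    deltas = {}
    for name, p_new in model.named_parameters():
        p_old = old_parameters[name]
        num = torch.linalg.norm((p_new.detach() - p_old).float())
        den = torch.linalg.norm(p_old.float()) + 1e-20  # avoid division by zero
        deltas[name] = (num / den).item()
    return deltas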
@@ -783,6 +796,13 @@ def run(rank, world_size, args):
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
         return 1.0 <= c.duration <= 20.0
 
     num_in_total = len(train_cuts)
@@ -797,7 +817,9 @@ def run(rank, world_size, args):
     logging.info(f"After removing short and long utterances: {num_left}")
     logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
 
-    if checkpoints and "sampler" in checkpoints:
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
         sampler_state_dict = checkpoints["sampler"]
     else:
         sampler_state_dict = None

View File

@@ -23,6 +23,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any, Dict, Optional
 
+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,

@@ -34,11 +35,20 @@ from lhotse.dataset import (
     SpecAugment,
 )
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
 from torch.utils.data import DataLoader
 
 from icefall.utils import str2bool
 
 
+class _SeedWorkers:
+    def __init__(self, seed: int):
+        self.seed = seed
+
+    def __call__(self, worker_id: int):
+        fix_random_seed(self.seed + worker_id)
+
+
 class LibriSpeechAsrDataModule:
     """
     DataModule for k2 ASR experiments.

@@ -301,12 +311,18 @@ class LibriSpeechAsrDataModule:
             logging.info("Loading sampler state dict")
             train_sampler.load_state_dict(sampler_state_dict)
 
+        # 'seed' is derived from the current random state, which will have
+        # previously been set in the main process.
+        seed = torch.randint(0, 100000, ()).item()
+        worker_init_fn = _SeedWorkers(seed)
+
         train_dl = DataLoader(
             train,
             sampler=train_sampler,
             batch_size=None,
             num_workers=self.args.num_workers,
             persistent_workers=False,
+            worker_init_fn=worker_init_fn,
         )
 
         return train_dl
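A self-contained sketch of the idea behind the _SeedWorkers hook added above. The toy dataset below is a hypothetical stand-in, and torch.manual_seed stands in for lhotse's fix_random_seed: every DataLoader worker is seeded with seed + worker_id, so augmentation randomness differs across workers yet stays reproducible across runs.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    # Hypothetical stand-in for the real training dataset.
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # The random draw depends on the per-worker RNG state set below.
        return idx, torch.rand(1).item()


class SeedWorkers:
    # Same shape as the _SeedWorkers class added above.
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        torch.manual_seed(self.seed + worker_id)


if __name__ == "__main__":
    dl = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=SeedWorkers(seed=42),
    )
    for batch in dl:
        print(batch)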

View File

@@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple

@@ -393,7 +394,11 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2"
 
 import argparse
 import logging
+import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Optional, Tuple

@@ -397,7 +398,11 @@ def compute_loss(
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
-    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()

View File

@@ -17,6 +17,7 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
+import k2
 import torch
 
 from model import Transducer

@@ -24,7 +25,7 @@ from model import Transducer
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.

@@ -80,7 +81,7 @@ def greedy_search(
         logits = model.joiner(
             current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
         )
-        # logits is (1, 1, 1, vocab_size)
+        # logits is (1, vocab_size)
 
         y = logits.argmax().item()
         if y != blank_id:

@@ -101,6 +102,75 @@ def greedy_search(
     return hyp
def greedy_search_batch(
model: Transducer, encoder_out: torch.Tensor
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
Returns:
Return a list-of-list of token IDs containing the decoded results.
len(ans) equals to encoder_out.size(0).
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (batch_size, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_out: (batch_size, 1, decoder_out_dim)
encoder_out_len = torch.ones(batch_size, dtype=torch.int32)
decoder_out_len = torch.ones(batch_size, dtype=torch.int32)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
) # (batch_size, context_size)
decoder_out = model.decoder(
decoder_input,
need_pad=False,
) # (batch_size, 1, decoder_out_dim)
ans = [h[context_size:] for h in hyps]
return ans
@dataclass
class Hypothesis:
    # The predicted tokens so far.

@@ -252,9 +322,11 @@ def run_decoder(
     device = model.device
 
-    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [ys[-context_size:]],
+        device=device,
+        dtype=torch.int64,
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_cache[key] = decoder_out

@@ -314,13 +386,158 @@ def run_joiner(
     return log_prob
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
hyps:
len(hyps) == batch_size. It contains the current hypothesis for
each utterance in the batch.
Returns:
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
the shape is on CPU.
"""
num_hyps = [len(h) for h in hyps]
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
# to get exclusive sum later.
num_hyps.insert(0, 0)
num_hyps = torch.tensor(num_hyps)
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
ans = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
)
return ans
def modified_beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcodded.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
beam:
Number of active paths during the beam search.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
device = model.device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
# current_encoder_out's shape is: (batch_size, 1, encoder_out_dim)
hyps_shape = _get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
# decoder_output is of shape (num_hyps, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
encoder_out_len.expand(decoder_out.size(0)),
decoder_out_len.expand(decoder_out.size(0)),
)
# logits is of shape (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
return ans
def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
    """It limits the maximum number of symbols per frame to 1.

    It decodes only one utterance at a time. We keep it only for reference.
    The function :func:`modified_beam_search` should be preferred as it
    supports batch decoding.

    Args:
      model:
        An instance of `Transducer`.

@@ -341,12 +558,6 @@ def modified_beam_search(
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
-
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-
     T = encoder_out.size(1)
 
     B = HypothesisList()

View File

@@ -109,8 +109,11 @@ class Conformer(Transformer):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        # Caution: We assume the subsampling factor is 4!
-        lengths = ((x_lens - 1) // 2 - 1) // 2
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            # Caution: We assume the subsampling factor is 4!
+            lengths = ((x_lens - 1) // 2 - 1) // 2
+
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
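For reference, a tiny sketch (the frame counts are made up) of what the length formula above computes when the subsampling factor is 4, i.e. two stride-2 convolutions:

import torch

x_lens = torch.tensor([100, 83, 16])  # hypothetical numbers of input frames

# Roughly T -> ((T - 1) // 2 - 1) // 2 after the two stride-2 layers.
lengths = ((x_lens - 1) // 2 - 1) // 2
print(lengths)  # tensor([24, 20, 3])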

View File

@@ -55,14 +55,15 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,

@@ -135,7 +136,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )

@@ -143,70 +144,6 @@ def get_parser():
     return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,

@@ -251,32 +188,47 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
-    batch_size = encoder_out.size(0)
-
-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.decoding_method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    hyp_list: List[List[int]] = []
+
+    if (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -487,8 +439,5 @@ def main():
     logging.info("Done!")
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()

View File

@@ -59,17 +59,15 @@ from typing import List
 import kaldifeat
 import sentencepiece as spm
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
 
-from icefall.env import get_env_info
-from icefall.utils import AttributeDict
 
 
 def get_parser():
@@ -115,6 +113,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )
 
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,

@@ -132,7 +137,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame. Used only when
         --method is greedy_search.
         """,

@@ -141,70 +146,6 @@ def get_parser():
     return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files( def read_sound_files(
filenames: List[str], expected_sample_rate: float filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
@ -294,33 +235,45 @@ def main():
) )
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyp_list = []
msg = f"Using {params.method}" msg = f"Using {params.method}"
if params.method == "beam_search": if params.method == "beam_search":
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) if params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):

View File

@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse import argparse
import logging import logging
import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple from typing import Optional, Tuple
@ -419,7 +420,11 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()

View File

@ -22,6 +22,7 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import ( from lhotse.dataset import (
BucketingSampler, BucketingSampler,
@ -34,11 +35,20 @@ from lhotse.dataset.input_strategies import (
OnTheFlyFeatures, OnTheFlyFeatures,
PrecomputedFeatures, PrecomputedFeatures,
) )
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from icefall.utils import str2bool from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class AsrDataModule: class AsrDataModule:
def __init__(self, args: argparse.Namespace): def __init__(self, args: argparse.Namespace):
self.args = args self.args = args
@ -253,12 +263,19 @@ class AsrDataModule:
) )
logging.info("About to create train dataloader") logging.info("About to create train dataloader")
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader( train_dl = DataLoader(
train, train,
sampler=train_sampler, sampler=train_sampler,
batch_size=None, batch_size=None,
num_workers=self.args.num_workers, num_workers=self.args.num_workers,
persistent_workers=False, persistent_workers=False,
worker_init_fn=worker_init_fn,
) )
return train_dl return train_dl
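
The _SeedWorkers callable above gives every DataLoader worker a distinct, reproducible seed (the recipe uses lhotse's fix_random_seed). Below is a toy, runnable variant of the same idea using torch.manual_seed and a dummy dataset; the dataset and loader settings are placeholders, not the recipe's.

import torch
from torch.utils.data import DataLoader, Dataset


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        # Offset the base seed by the worker id so workers do not
        # produce identical random augmentations.
        torch.manual_seed(self.seed + worker_id)


class ToyDataset(Dataset):
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Depends on the worker's RNG state, hence the seeding matters.
        return torch.rand(1)


if __name__ == "__main__":
    seed = torch.randint(0, 100000, ()).item()
    dl = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=_SeedWorkers(seed),
    )
    for batch in dl:
        print(batch)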

View File

@ -46,15 +46,16 @@ import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from beam_search import beam_search, greedy_search, modified_beam_search from beam_search import (
from conformer import Conformer beam_search,
from decoder import Decoder greedy_search,
from joiner import Joiner greedy_search_batch,
modified_beam_search,
)
from librispeech import LibriSpeech from librispeech import LibriSpeech
from model import Transducer from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
setup_logger, setup_logger,
@ -127,7 +128,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=3, default=1,
help="""Maximum number of symbols per frame. help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""", Used only when --decoding_method is greedy_search""",
) )
@ -135,71 +136,6 @@ def get_parser():
return parser return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def decode_one_batch( def decode_one_batch(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
@ -244,32 +180,47 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
) )
hyps = [] hyp_list = []
batch_size = encoder_out.size(0) batch_size = encoder_out.size(0)
for i in range(batch_size): if (
# fmt: off params.decoding_method == "greedy_search"
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] and params.max_sym_per_frame == 1
# fmt: on ):
if params.decoding_method == "greedy_search": hyp_list = greedy_search_batch(
hyp = greedy_search( model=model,
model=model, encoder_out=encoder_out,
encoder_out=encoder_out_i, )
max_sym_per_frame=params.max_sym_per_frame, elif params.decoding_method == "modified_beam_search":
) hyp_list = modified_beam_search(
elif params.decoding_method == "beam_search": model=model,
hyp = beam_search( encoder_out=encoder_out,
model=model, encoder_out=encoder_out_i, beam=params.beam_size beam=params.beam_size,
) )
elif params.decoding_method == "modified_beam_search": else:
hyp = modified_beam_search( for i in range(batch_size):
model=model, encoder_out=encoder_out_i, beam=params.beam_size # fmt: off
) encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
else: # fmt: on
raise ValueError( if params.decoding_method == "greedy_search":
f"Unsupported decoding method: {params.decoding_method}" hyp = greedy_search(
) model=model,
hyps.append(sp.decode(hyp).split()) encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
return {"greedy_search": hyps} return {"greedy_search": hyps}
@ -483,8 +434,5 @@ def main():
logging.info("Done!") logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@ -59,17 +59,15 @@ from typing import List
import kaldifeat import kaldifeat
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn
import torchaudio import torchaudio
from beam_search import beam_search, greedy_search, modified_beam_search from beam_search import (
from conformer import Conformer beam_search,
from decoder import Decoder greedy_search,
from joiner import Joiner greedy_search_batch,
from model import Transducer modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.env import get_env_info
from icefall.utils import AttributeDict
def get_parser(): def get_parser():
@ -115,6 +113,13 @@ def get_parser():
"The sample rate has to be 16kHz.", "The sample rate has to be 16kHz.",
) )
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument( parser.add_argument(
"--beam-size", "--beam-size",
type=int, type=int,
@ -132,7 +137,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
type=int, type=int,
default=3, default=1,
help="""Maximum number of symbols per frame. Used only when help="""Maximum number of symbols per frame. Used only when
--method is greedy_search. --method is greedy_search.
""", """,
@ -141,70 +146,6 @@ def get_parser():
return parser return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files( def read_sound_files(
filenames: List[str], expected_sample_rate: float filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
@ -294,33 +235,46 @@ def main():
) )
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyp_list = []
msg = f"Using {params.method}" msg = f"Using {params.method}"
if params.method == "beam_search": if params.method == "beam_search":
msg += f" with beam size {params.beam_size}" msg += f" with beam size {params.beam_size}"
logging.info(msg) logging.info(msg)
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split()) if params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_list = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
elif params.method == "modified_beam_search":
hyp_list = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):

View File

@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
import argparse import argparse
import logging import logging
import random import random
import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple from typing import Optional, Tuple
@ -466,7 +467,11 @@ def compute_loss(
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()

View File

@ -0,0 +1,65 @@
# isort:skip_file
from . import (
checkpoint,
decode,
dist,
env,
utils
)
from .checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
remove_checkpoints,
save_checkpoint,
save_checkpoint_with_global_batch_idx,
)
from .decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from .dist import (
cleanup_dist,
setup_dist,
)
from .env import (
get_env_info,
get_git_branch_name,
get_git_date,
get_git_sha1,
)
from .utils import (
AttributeDict,
MetricsTracker,
add_eos,
add_sos,
concat,
encode_supervisions,
get_alignments,
get_executor,
get_texts,
l1_norm,
l2_norm,
linf_norm,
load_alignments,
make_pad_mask,
measure_gradient_norms,
measure_weight_norms,
optim_step_and_measure_param_change,
save_alignments,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
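
With these re-exports in place, downstream scripts can import common helpers from the package root rather than from individual submodules. A small illustrative snippet, assuming icefall and its dependencies are installed:

from icefall import AttributeDict, str2bool

params = AttributeDict({"feature_dim": 80, "subsampling_factor": 4})
print(params.feature_dim)  # attribute-style access -> 80
print(str2bool("true"))    # -> True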

View File

@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx(
) )
def find_checkpoints(out_dir: Path) -> List[str]: def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]:
"""Find all available checkpoints in a directory. """Find all available checkpoints in a directory.
The checkpoint filenames have the form: `checkpoint-xxx.pt` The checkpoint filenames have the form: `checkpoint-xxx.pt`
where xxx is a numerical value. where xxx is a numerical value.
Assume you have the following checkpoints in the folder `foo`:
- checkpoint-1.pt
- checkpoint-20.pt
- checkpoint-300.pt
- checkpoint-4000.pt
Case 1 (Return all checkpoints)::
find_checkpoints(out_dir='foo')
Case 2 (Return checkpoints no older than checkpoint-20.pt, i.e.,
checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt)::
find_checkpoints(out_dir='foo', iteration=20)
Case 3 (Return checkpoints no newer than checkpoint-20.pt, i.e.,
checkpoint-20.pt, checkpoint-1.pt)::
find_checkpoints(out_dir='foo', iteration=-20)
Args: Args:
out_dir: out_dir:
The directory in which to search for checkpoints. The directory in which to search for checkpoints.
iteration:
If it is 0, return all available checkpoints.
If it is positive, return the checkpoints whose iteration number is
greater than or equal to `iteration`.
If it is negative, return the checkpoints whose iteration number is
less than or equal to `-iteration`.
Returns: Returns:
Return a list of checkpoint filenames, sorted in descending Return a list of checkpoint filenames, sorted in descending
order by the numerical value in the filename. order by the numerical value in the filename.
""" """
checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt"))
pattern = re.compile(r"checkpoint-([0-9]+).pt") pattern = re.compile(r"checkpoint-([0-9]+).pt")
idx_checkpoints = [ iter_checkpoints = [
(int(pattern.search(c).group(1)), c) for c in checkpoints (int(pattern.search(c).group(1)), c) for c in checkpoints
] ]
# iter_checkpoints is a list of tuples. Each tuple contains
# two elements: (iteration_number, checkpoint-iteration_number.pt)
iter_checkpoints = sorted(
iter_checkpoints, reverse=True, key=lambda x: x[0]
)
if iteration >= 0:
ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration]
else:
ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration]
idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0])
ans = [ic[1] for ic in idx_checkpoints]
return ans return ans
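
A quick illustration of the new iteration argument, assuming a directory foo that contains checkpoint-1.pt, checkpoint-20.pt, checkpoint-300.pt, and checkpoint-4000.pt as in the docstring above:

from icefall.checkpoint import find_checkpoints

# All checkpoints, newest first:
#   checkpoint-4000.pt, checkpoint-300.pt, checkpoint-20.pt, checkpoint-1.pt
print(find_checkpoints("foo"))

# Checkpoints at or after iteration 20:
#   checkpoint-4000.pt, checkpoint-300.pt, checkpoint-20.pt
print(find_checkpoints("foo", iteration=20))

# Checkpoints at or before iteration 20:
#   checkpoint-20.pt, checkpoint-1.pt
print(find_checkpoints("foo", iteration=-20))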

View File

@ -135,8 +135,13 @@ def get_diagnostics_for_dim(
return "" return ""
count = sum(counts) count = sum(counts)
stats = stats / count stats = stats / count
stats, _ = torch.symeig(stats) try:
stats = stats.abs().sqrt() eigs, _ = torch.symeig(stats)
stats = eigs.abs().sqrt()
except: # noqa
print("Error getting eigenvalues, trying another method.")
eigs = torch.linalg.eigvals(stats)
stats = eigs.abs().sqrt()
# sqrt so it reflects data magnitude, like stddev, not variance # sqrt so it reflects data magnitude, like stddev, not variance
elif sizes_same: elif sizes_same:
stats = torch.stack(stats).sum(dim=0) stats = torch.stack(stats).sum(dim=0)
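
The try/except above guards against torch.symeig, which is deprecated and later removed in newer PyTorch releases, by falling back to torch.linalg.eigvals. A standalone sketch of the same fallback with a placeholder symmetric matrix:

import torch

stats = torch.eye(4) * 2.0  # placeholder for the accumulated stats matrix

try:
    eigs, _ = torch.symeig(stats)
except Exception:  # symeig deprecated/removed or failed
    eigs = torch.linalg.eigvals(stats)

print(eigs.abs().sqrt())  # roughly sqrt(2) for this example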

View File

@ -95,6 +95,7 @@ def get_env_info() -> Dict[str, Any]:
"k2-git-sha1": k2.version.__git_sha1__, "k2-git-sha1": k2.version.__git_sha1__,
"k2-git-date": k2.version.__git_date__, "k2-git-date": k2.version.__git_date__,
"lhotse-version": lhotse.__version__, "lhotse-version": lhotse.__version__,
"torch-version": torch.__version__,
"torch-cuda-available": torch.cuda.is_available(), "torch-cuda-available": torch.cuda.is_available(),
"torch-cuda-version": torch.version.cuda, "torch-cuda-version": torch.version.cuda,
"python-version": sys.version[:3], "python-version": sys.version[:3],

View File

@ -25,15 +25,14 @@ from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, Iterable, List, TextIO, Optional, Tuple, Union from typing import Dict, Iterable, List, TextIO, Tuple, Union
import k2 import k2
import k2.version import k2.version
import kaldialign import kaldialign
import torch import torch
import torch.nn as nn
import torch.distributed as dist import torch.distributed as dist
from torch.cuda.amp import GradScaler import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
Pathlike = Union[str, Path] Pathlike = Union[str, Path]
@ -758,11 +757,10 @@ def measure_gradient_norms(
def optim_step_and_measure_param_change( def optim_step_and_measure_param_change(
model: nn.Module, model: nn.Module,
optimizer: torch.optim.Optimizer, old_parameters: Dict[str, nn.parameter.Parameter],
scaler: Optional[GradScaler] = None,
) -> Dict[str, float]: ) -> Dict[str, float]:
""" """
Perform model weight update and measure the "relative change in parameters per minibatch." Measure the "relative change in parameters per minibatch."
It is understood as a ratio between the L2 norm of the difference between original and updated parameters, It is understood as a ratio between the L2 norm of the difference between original and updated parameters,
and the L2 norm of the original parameter. It is given by the formula: and the L2 norm of the original parameter. It is given by the formula:
@ -770,16 +768,31 @@ def optim_step_and_measure_param_change(
\begin{aligned} \begin{aligned}
\delta = \frac{\Vert\theta - \theta_{new}\Vert}{\Vert\theta\Vert} \delta = \frac{\Vert\theta - \theta_{new}\Vert}{\Vert\theta\Vert}
\end{aligned} \end{aligned}
"""
param_copy = {n: p.detach().clone() for n, p in model.named_parameters()} This function is supposed to be used as follows:
if scaler:
scaler.step(optimizer) .. code-block:: python
else:
old_parameters = {
n: p.detach().clone() for n, p in model.named_parameters()
}
optimizer.step() optimizer.step()
deltas = optim_step_and_measure_param_change(model, old_parameters)
Args:
model: A torch.nn.Module instance.
old_parameters:
A Dict of named_parameters before optimizer.step().
Return:
A Dict containing the relative change for each parameter.
"""
relative_change = {} relative_change = {}
with torch.no_grad(): with torch.no_grad():
for n, p_new in model.named_parameters(): for n, p_new in model.named_parameters():
p_orig = param_copy[n] p_orig = old_parameters[n]
delta = l2_norm(p_orig - p_new) / l2_norm(p_orig) delta = l2_norm(p_orig - p_new) / l2_norm(p_orig)
relative_change[n] = delta.item() relative_change[n] = delta.item()
return relative_change return relative_change
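
Under the new convention the caller snapshots the parameters, performs the optimizer step itself (optionally through a GradScaler), and then asks for the relative change. A toy, self-contained sketch; the model and optimizer are placeholders, not the recipe's:

import torch
import torch.nn as nn

from icefall.utils import optim_step_and_measure_param_change

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(3, 4)).sum()
loss.backward()

# Snapshot before the update, step, then measure the relative change.
old_parameters = {n: p.detach().clone() for n, p in model.named_parameters()}
optimizer.step()

deltas = optim_step_and_measure_param_change(model, old_parameters)
print(deltas)  # e.g. {'weight': 0.07, 'bias': 0.9} (values vary per run)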

View File

@ -1,5 +1,6 @@
[tool.isort] [tool.isort]
profile = "black" profile = "black"
skip = ["icefall/__init__.py"]
[tool.black] [tool.black]
line-length = 80 line-length = 80
@ -9,4 +10,5 @@ exclude = '''
| \.github | \.github
)/ )/
| make_kn_lm.py | make_kn_lm.py
| icefall\/__init__\.py
''' '''

View File

@ -11,7 +11,7 @@ graphviz==0.19.1
-f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu
-f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu
-f https://k2-fsa.org/nightly/ k2==1.9.dev20211101+cpu.torch1.10.0 -f https://k2-fsa.org/nightly/ k2==1.14.dev20220316+cpu.torch1.10.0
git+https://github.com/lhotse-speech/lhotse git+https://github.com/lhotse-speech/lhotse
kaldilm==1.11 kaldilm==1.11