Merge changes from master

Commit 5d9dae3064 by Daniel Povey, 2022-03-24 12:59:36 +08:00
12 changed files with 773 additions and 605 deletions

Changed file 1 of 12

@@ -55,18 +55,17 @@ from typing import List
 import kaldifeat
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
 
-from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict
 
 
 def get_parser():
@@ -111,6 +110,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )
 
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,
@@ -137,70 +143,6 @@ def get_parser():
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-            "sample_rate": 16000,
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
@@ -225,6 +167,7 @@ def read_sound_files(
     return ans
 
 
+@torch.no_grad()
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -249,7 +192,7 @@ def main():
     model = get_transducer_model(params)
 
     checkpoint = torch.load(args.checkpoint, map_location="cpu")
-    model.load_state_dict(checkpoint["model"])
+    model.load_state_dict(checkpoint["model"], strict=False)
     model.to(device)
     model.eval()
     model.device = device
@@ -279,12 +222,22 @@ def main():
         features, batch_first=True, padding_value=math.log(1e-10)
     )
 
-    hyps = []
-    with torch.no_grad():
-        encoder_out, encoder_out_lens = model.encoder(
-            x=features, x_lens=feature_lens
-        )
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
+    hyp_list = []
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
         for i in range(encoder_out.size(0)):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -301,16 +254,14 @@ def main():
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.method}"
                 )
+            hyp_list.append(hyp)
 
+    hyps = []
+    for hyp in hyp_list:
         hyps.append([lexicon.token_table[i] for i in hyp])
 
     s = "\n"
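Note: the @torch.no_grad() decorator added above is equivalent to wrapping the whole body of main() in the with torch.no_grad(): block that the removed lines used; autograd is disabled for the entire call. A minimal, self-contained illustration (the helper names are ours, not the script's):

    import torch

    model = torch.nn.Linear(4, 2)
    x = torch.randn(3, 4)

    @torch.no_grad()
    def decode_decorated():
        # Autograd is disabled for the whole function body.
        return model(x)

    def decode_with_context():
        with torch.no_grad():
            return model(x)

    assert torch.equal(decode_decorated(), decode_with_context())
    assert not decode_decorated().requires_grad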

Changed file 2 of 12

@@ -55,18 +55,17 @@ from typing import List
 import kaldifeat
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
 
-from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.utils import AttributeDict
 
 
 def get_parser():
@@ -111,6 +110,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )
 
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,
@@ -137,70 +143,6 @@ def get_parser():
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-            "sample_rate": 16000,
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
@@ -225,6 +167,7 @@ def read_sound_files(
     return ans
 
 
+@torch.no_grad()
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -279,12 +222,22 @@ def main():
         features, batch_first=True, padding_value=math.log(1e-10)
     )
 
-    hyps = []
-    with torch.no_grad():
-        encoder_out, encoder_out_lens = model.encoder(
-            x=features, x_lens=feature_lens
-        )
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
 
+    hyp_list = []
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
        for i in range(encoder_out.size(0)):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -301,16 +254,14 @@ def main():
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.method}"
                 )
+            hyp_list.append(hyp)
 
+    hyps = []
+    for hyp in hyp_list:
         hyps.append([lexicon.token_table[i] for i in hyp])
 
     s = "\n"

Changed file 3 of 12

@@ -106,7 +106,7 @@ def fast_beam_search(
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.
@@ -178,6 +178,68 @@ def greedy_search(
     return hyp
 
 
+def greedy_search_batch(
+    model: Transducer, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+    Returns:
+      Return a list-of-list of token IDs containing the decoded results.
+      len(ans) equals to encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = model.device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
+        # logits' shape: (batch_size, 1, 1, vocab_size)
+        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
+        assert logits.ndim == 2, logits.shape
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )
+            decoder_out = model.decoder(decoder_input, need_pad=False)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 @dataclass
 class Hypothesis:
     # The predicted tokens so far.
@@ -304,13 +366,156 @@ class HypothesisList(object):
         return ", ".join(s)
 
 
+def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+    """Return a ragged shape with axes [utt][num_hyps].
+
+    Args:
+      hyps:
+        len(hyps) == batch_size. It contains the current hypothesis for
+        each utterance in the batch.
+    Returns:
+      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
+      the shape is on CPU.
+    """
+    num_hyps = [len(h) for h in hyps]
+
+    # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
+    # to get exclusive sum later.
+    num_hyps.insert(0, 0)
+
+    num_hyps = torch.tensor(num_hyps)
+    row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
+    ans = k2.ragged.create_ragged_shape2(
+        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
+    )
+    return ans
+
+
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+) -> List[List[int]]:
+    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C).
+      beam:
+        Number of active paths during the beam search.
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding results
+      for the i-th utterance.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = model.device
+
+    B = [HypothesisList() for _ in range(batch_size)]
+    for i in range(batch_size):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            )
+        )
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+
+        hyps_shape = _get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+            topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+    ans = [h.ys[context_size:] for h in best_hyps]
+
+    return ans
+
+
+def _deprecated_modified_beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
 
+    It decodes only one utterance at a time. We keep it only for reference.
+    The function :func:`modified_beam_search` should be preferred as it
+    supports batch decoding.
+
     Args:
       model:
         An instance of `Transducer`.
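Note: greedy_search_batch above advances every utterance one frame at a time, which is only valid when at most one symbol is emitted per frame; that is why the scripts in this commit fall back to per-utterance search when --max-sym-per-frame != 1, and why the flag's default changes from 3 to 1 throughout. A self-contained toy of the loop, with a random tensor standing in for the joiner output (no real Transducer involved):

    import torch

    blank_id, context_size = 0, 2
    batch_size, T, vocab_size = 2, 3, 5
    torch.manual_seed(0)
    frame_logits = torch.randn(T, batch_size, vocab_size)  # fake joiner output

    hyps = [[blank_id] * context_size for _ in range(batch_size)]
    for t in range(T):
        y = frame_logits[t].argmax(dim=1).tolist()  # best token per utterance
        emitted = False
        for i, v in enumerate(y):
            if v != blank_id:
                hyps[i].append(v)
                emitted = True
        if emitted:
            # The real code re-runs the decoder here, feeding the last
            # context_size tokens of every hypothesis in the batch.
            pass

    ans = [h[context_size:] for h in hyps]  # strip the blank context prefix
    print(ans)  # decoded token IDs per utterance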

Changed file 4 of 12

@@ -71,6 +71,7 @@ from beam_search import (
     beam_search,
     fast_beam_search,
     greedy_search,
+    greedy_search_batch,
     modified_beam_search,
 )
 from train import get_params, get_transducer_model
@@ -191,7 +192,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
@@ -261,6 +262,24 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
@@ -280,12 +299,6 @@ def decode_one_batch(
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.decoding_method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"

Changed file 5 of 12

@@ -50,7 +50,12 @@ import kaldifeat
 import sentencepiece as spm
 import torch
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
 from train import get_params, get_transducer_model
@@ -122,7 +127,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame. Used only when
         --method is greedy_search.
         """,
@@ -224,6 +229,23 @@ def main():
     if params.method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)
 
+    if params.method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
+    else:
         for i in range(num_waves):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -236,11 +258,9 @@
             )
         elif params.method == "beam_search":
             hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+                model=model,
+                encoder_out=encoder_out_i,
+                beam=params.beam_size,
             )
         else:
             raise ValueError(f"Unsupported method: {params.method}")

Changed file 6 of 12

@@ -392,13 +392,17 @@ def load_checkpoint_if_available(
         "batch_idx_train",
         "best_train_loss",
         "best_valid_loss",
+        "cur_batch_idx",
     ]
     for k in keys:
         params[k] = saved_params[k]
 
-    params["start_epoch"] = saved_params["cur_epoch"]
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
 
     return saved_params
@@ -783,6 +787,13 @@ def run(rank, world_size, args):
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
         return 1.0 <= c.duration <= 20.0
 
     num_in_total = len(train_cuts)
@@ -797,7 +808,9 @@ def run(rank, world_size, args):
     logging.info(f"After removing short and long utterances: {num_left}")
     logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")
 
-    if checkpoints and "sampler" in checkpoints:
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
         sampler_state_dict = checkpoints["sampler"]
     else:
         sampler_state_dict = None
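Note: taken together, the three hunks above make mid-epoch resumption explicit: "cur_batch_idx" joins the keys copied from the checkpoint, the epoch and batch counters are restored only when params.start_batch > 0, and the sampler state is reloaded only for checkpoints saved in the middle of an epoch. A condensed sketch of that policy (the function name is ours, not icefall's):

    def restore_progress(params, saved_params, checkpoints):
        # Only a mid-epoch restart (start_batch > 0) resumes counters.
        if params.start_batch > 0:
            if "cur_epoch" in saved_params:
                params["start_epoch"] = saved_params["cur_epoch"]
            if "cur_batch_idx" in saved_params:
                params["cur_batch_idx"] = saved_params["cur_batch_idx"]
        # The sampler state is likewise restored only for mid-epoch
        # checkpoints, so epoch-boundary restarts reshuffle from scratch.
        if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
            return checkpoints["sampler"]
        return None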

Changed file 7 of 12

@@ -17,6 +17,7 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional
 
+import k2
 import torch
 
 from model import Transducer
@@ -24,7 +25,7 @@ from model import Transducer
 def greedy_search(
     model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int
 ) -> List[int]:
-    """
+    """Greedy search for a single utterance.
     Args:
       model:
         An instance of `Transducer`.
@@ -80,7 +81,7 @@ def greedy_search(
         logits = model.joiner(
             current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
         )
-        # logits is (1, 1, 1, vocab_size)
+        # logits is (1, vocab_size)
 
         y = logits.argmax().item()
         if y != blank_id:
@@ -101,6 +102,75 @@ def greedy_search(
     return hyp
 
 
+def greedy_search_batch(
+    model: Transducer, encoder_out: torch.Tensor
+) -> List[List[int]]:
+    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C), where N >= 1.
+    Returns:
+      Return a list-of-list of token IDs containing the decoded results.
+      len(ans) equals to encoder_out.size(0).
+    """
+    assert encoder_out.ndim == 3
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+    device = model.device
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+
+    hyps = [[blank_id] * context_size for _ in range(batch_size)]
+
+    decoder_input = torch.tensor(
+        hyps,
+        device=device,
+        dtype=torch.int64,
+    )  # (batch_size, context_size)
+
+    decoder_out = model.decoder(decoder_input, need_pad=False)
+    # decoder_out: (batch_size, 1, decoder_out_dim)
+
+    encoder_out_len = torch.ones(batch_size, dtype=torch.int32)
+    decoder_out_len = torch.ones(batch_size, dtype=torch.int32)
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
+        # current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
+        logits = model.joiner(
+            current_encoder_out, decoder_out, encoder_out_len, decoder_out_len
+        )  # (batch_size, vocab_size)
+
+        assert logits.ndim == 2, logits.shape
+
+        y = logits.argmax(dim=1).tolist()
+        emitted = False
+        for i, v in enumerate(y):
+            if v != blank_id:
+                hyps[i].append(v)
+                emitted = True
+        if emitted:
+            # update decoder output
+            decoder_input = [h[-context_size:] for h in hyps]
+            decoder_input = torch.tensor(
+                decoder_input,
+                device=device,
+                dtype=torch.int64,
+            )  # (batch_size, context_size)
+            decoder_out = model.decoder(
+                decoder_input,
+                need_pad=False,
+            )  # (batch_size, 1, decoder_out_dim)
+
+    ans = [h[context_size:] for h in hyps]
+    return ans
+
+
 @dataclass
 class Hypothesis:
     # The predicted tokens so far.
@@ -252,9 +322,11 @@ def run_decoder(
     device = model.device
 
-    decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape(
-        1, context_size
-    )
+    decoder_input = torch.tensor(
+        [ys[-context_size:]],
+        device=device,
+        dtype=torch.int64,
+    ).reshape(1, context_size)
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
     decoder_cache[key] = decoder_out
@@ -314,13 +386,158 @@ def run_joiner(
     return log_prob
 
 
+def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+    """Return a ragged shape with axes [utt][num_hyps].
+
+    Args:
+      hyps:
+        len(hyps) == batch_size. It contains the current hypothesis for
+        each utterance in the batch.
+    Returns:
+      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
+      the shape is on CPU.
+    """
+    num_hyps = [len(h) for h in hyps]
+
+    # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
+    # to get exclusive sum later.
+    num_hyps.insert(0, 0)
+
+    num_hyps = torch.tensor(num_hyps)
+    row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
+    ans = k2.ragged.create_ragged_shape2(
+        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
+    )
+    return ans
+
+
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+) -> List[List[int]]:
+    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C).
+      beam:
+        Number of active paths during the beam search.
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding results
+      for the i-th utterance.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    context_size = model.decoder.context_size
+    device = model.device
+
+    B = [HypothesisList() for _ in range(batch_size)]
+    for i in range(batch_size):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            )
+        )
+
+    encoder_out_len = torch.tensor([1])
+    decoder_out_len = torch.tensor([1])
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :]  # noqa
+        # current_encoder_out's shape is: (batch_size, 1, encoder_out_dim)
+
+        hyps_shape = _get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False)
+        # decoder_out is of shape (num_hyps, 1, decoder_output_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+            encoder_out_len.expand(decoder_out.size(0)),
+            decoder_out_len.expand(decoder_out.size(0)),
+        )
+        # logits is of shape (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+            topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+    ans = [h.ys[context_size:] for h in best_hyps]
+
+    return ans
+
+
+def _deprecated_modified_beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
 
+    It decodes only one utterance at a time. We keep it only for reference.
+    The function :func:`modified_beam_search` should be preferred as it
+    supports batch decoding.
+
     Args:
       model:
         An instance of `Transducer`.
@@ -341,12 +558,6 @@ def modified_beam_search(
     device = model.device
 
-    decoder_input = torch.tensor(
-        [blank_id] * context_size, device=device
-    ).reshape(1, context_size)
-
-    decoder_out = model.decoder(decoder_input, need_pad=False)
-
     T = encoder_out.size(1)
 
     B = HypothesisList()
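Note: the ragged bookkeeping in modified_beam_search can be traced with plain torch. Each utterance contributes len(hyps) rows of vocab_size scores; after flattening, row_splits * vocab_size delimits each utterance's segment, and a top-k index decomposes back into a (hypothesis, token) pair via // and %. A self-contained sketch (torch only, standing in for the k2 ragged tensors):

    import torch

    vocab_size, beam = 5, 3
    num_hyps = torch.tensor([0, 2, 3])      # per-utterance hyp counts, 0-prefixed
    row_splits = torch.cumsum(num_hyps, 0)  # tensor([0, 2, 5]): exclusive sums

    torch.manual_seed(0)
    log_probs = torch.randn(int(row_splits[-1]), vocab_size).reshape(-1)

    seg = row_splits * vocab_size  # segment boundaries in the flat tensor
    for i in range(len(num_hyps) - 1):
        topk_log_probs, topk_indexes = log_probs[seg[i] : seg[i + 1]].topk(beam)
        topk_hyp_indexes = (topk_indexes // vocab_size).tolist()   # which hyp in A[i]
        topk_token_indexes = (topk_indexes % vocab_size).tolist()  # which token
        print(i, topk_hyp_indexes, topk_token_indexes)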

Changed file 8 of 12

@@ -55,14 +55,15 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -135,7 +136,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
@@ -143,70 +144,6 @@ def get_parser():
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -251,9 +188,24 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
-    batch_size = encoder_out.size(0)
+    hyp_list: List[List[int]] = []
 
+    if (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
+        batch_size = encoder_out.size(0)
         for i in range(batch_size):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -266,17 +218,17 @@ def decode_one_batch(
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
-                    model=model, encoder_out=encoder_out_i, beam=params.beam_size
-                )
-            elif params.decoding_method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model, encoder_out=encoder_out_i, beam=params.beam_size
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
                 )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -487,8 +439,5 @@ def main():
     logging.info("Done!")
 
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()
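Note: the refactor above leans on the fact that SentencePiece's decode() accepts a list of token-ID lists and returns one string per entry, so the batched search output can be turned into word lists in a single call. A small example (the model path is hypothetical):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path

    hyp_tokens = [[23, 7, 41], [5, 9]]  # e.g. output of greedy_search_batch
    for hyp in sp.decode(hyp_tokens):   # one decoded string per utterance
        print(hyp.split())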

Changed file 9 of 12

@@ -59,17 +59,15 @@ from typing import List
 import kaldifeat
 import sentencepiece as spm
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
-
-from icefall.env import get_env_info
-from icefall.utils import AttributeDict
+from train import get_params, get_transducer_model
 
 
 def get_parser():
@@ -115,6 +113,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )
 
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,
@@ -132,7 +137,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame. Used only when
         --method is greedy_search.
         """,
@@ -141,70 +146,6 @@ def get_parser():
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "sample_rate": 16000,
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
@@ -294,11 +235,24 @@ def main():
     )
 
     num_waves = encoder_out.size(0)
-    hyps = []
+    hyp_list = []
     msg = f"Using {params.method}"
     if params.method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)
 
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
         for i in range(num_waves):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -311,16 +265,15 @@ def main():
             )
         elif params.method == "beam_search":
             hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+                model=model,
+                encoder_out=encoder_out_i,
+                beam=params.beam_size,
            )
         else:
             raise ValueError(f"Unsupported method: {params.method}")
+        hyp_list.append(hyp)
 
-        hyps.append(sp.decode(hyp).split())
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
 
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):

Changed file 10 of 12

@@ -46,15 +46,16 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import AsrDataModule
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from librispeech import LibriSpeech
-from model import Transducer
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -127,7 +128,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
@@ -135,71 +136,6 @@ def get_parser():
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -244,9 +180,24 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
+    hyp_list = []
     batch_size = encoder_out.size(0)
 
+    if (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
         for i in range(batch_size):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -259,17 +210,17 @@ def decode_one_batch(
                 )
             elif params.decoding_method == "beam_search":
                 hyp = beam_search(
-                    model=model, encoder_out=encoder_out_i, beam=params.beam_size
-                )
-            elif params.decoding_method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model, encoder_out=encoder_out_i, beam=params.beam_size
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
                 )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            hyp_list.append(sp.decode(hyp).split())
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
 
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -483,8 +434,5 @@ def main():
     logging.info("Done!")
 
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()

Changed file 11 of 12

@@ -59,17 +59,15 @@ from typing import List
 import kaldifeat
 import sentencepiece as spm
 import torch
-import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
-
-from icefall.env import get_env_info
-from icefall.utils import AttributeDict
+from train import get_params, get_transducer_model
 
 
 def get_parser():
@@ -115,6 +113,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )
 
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,
@@ -132,7 +137,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame. Used only when
         --method is greedy_search.
         """,
@@ -141,70 +146,6 @@ def get_parser():
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            "sample_rate": 16000,
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
@@ -294,11 +235,25 @@ def main():
     )
 
     num_waves = encoder_out.size(0)
-    hyps = []
+    hyp_list = []
     msg = f"Using {params.method}"
     if params.method == "beam_search":
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)
 
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
         for i in range(num_waves):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -311,16 +266,15 @@ def main():
             )
         elif params.method == "beam_search":
             hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        elif params.method == "modified_beam_search":
-            hyp = modified_beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+                model=model,
+                encoder_out=encoder_out_i,
+                beam=params.beam_size,
            )
         else:
             raise ValueError(f"Unsupported method: {params.method}")
+        hyp_list.append(hyp)
 
-        hyps.append(sp.decode(hyp).split())
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
 
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):

Changed file 12 of 12

@@ -11,7 +11,7 @@ graphviz==0.19.1
 
 -f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu
 -f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu
--f https://k2-fsa.org/nightly/ k2==1.9.dev20211101+cpu.torch1.10.0
+-f https://k2-fsa.org/nightly/ k2==1.14.dev20220316+cpu.torch1.10.0
 
 git+https://github.com/lhotse-speech/lhotse
 kaldilm==1.11
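Note: the k2 pin jumps from a 1.9 nightly to a 1.14 nightly, presumably because the new modified_beam_search relies on the k2.ragged.create_ragged_shape2 and k2.RaggedTensor(shape=..., value=...) APIs used in the beam_search.py changes above. A quick smoke test of that API, assuming k2 is installed:

    import k2
    import torch

    row_splits = torch.tensor([0, 2, 5], dtype=torch.int32)
    shape = k2.ragged.create_ragged_shape2(
        row_splits=row_splits, cached_tot_size=5
    )
    ragged = k2.RaggedTensor(
        shape=shape, value=torch.arange(5, dtype=torch.float32)
    )
    print(ragged)  # [ [ 0 1 ] [ 2 3 4 ] ]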