diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
index 31bab122c..e21844129 100755
--- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py
@@ -57,14 +57,15 @@
 import kaldifeat
 import torch
 import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
 
-from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict
@@ -111,6 +112,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )
 
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,
@@ -137,70 +145,6 @@
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-            "sample_rate": 16000,
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
@@ -225,6 +169,7 @@ def read_sound_files(
     return ans
 
 
+@torch.no_grad()
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -279,12 +224,22 @@ def main():
         features, batch_first=True, padding_value=math.log(1e-10)
     )
 
-    hyps = []
-    with torch.no_grad():
-        encoder_out, encoder_out_lens = model.encoder(
-            x=features, x_lens=feature_lens
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
+    hyp_list = []
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
         )
-
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
         for i in range(encoder_out.size(0)):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -301,17 +256,13 @@ def main():
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.method}"
                 )
-            hyps.append([lexicon.token_table[i] for i in hyp])
+            hyp_list.append(hyp)
+
+    hyps = [[lexicon.token_table[i] for i in hyp] for hyp in hyp_list]
 
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
index 698594e92..4f7da3bde 100755
--- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py
@@ -57,14 +57,15 @@
 import kaldifeat
 import torch
 import torch.nn as nn
 import torchaudio
-from beam_search import beam_search, greedy_search, modified_beam_search
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
 from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
 
-from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict
@@ -111,6 +112,13 @@ def get_parser():
         "The sample rate has to be 16kHz.",
     )
 
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
     parser.add_argument(
         "--beam-size",
         type=int,
@@ -137,70 +145,6 @@
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-            "sample_rate": 16000,
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def read_sound_files(
     filenames: List[str], expected_sample_rate: float
 ) -> List[torch.Tensor]:
@@ -225,6 +169,7 @@ def read_sound_files(
     return ans
 
 
+@torch.no_grad()
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -279,12 +224,22 @@ def main():
         features, batch_first=True, padding_value=math.log(1e-10)
     )
 
-    hyps = []
-    with torch.no_grad():
-        encoder_out, encoder_out_lens = model.encoder(
-            x=features, x_lens=feature_lens
+    encoder_out, encoder_out_lens = model.encoder(
+        x=features, x_lens=feature_lens
+    )
+    hyp_list = []
+    if params.method == "greedy_search" and params.max_sym_per_frame == 1:
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
         )
-
+    elif params.method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
         for i in range(encoder_out.size(0)):
             # fmt: off
             encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -301,17 +256,13 @@ def main():
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.method}"
                 )
-            hyps.append([lexicon.token_table[i] for i in hyp])
+            hyp_list.append(hyp)
+
+    hyps = [[lexicon.token_table[i] for i in hyp] for hyp in hyp_list]
 
     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
index d7e52da30..815e1c02a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py
@@ -232,7 +232,7 @@ def greedy_search_batch(
     decoder_input = torch.tensor(
         decoder_input,
         device=device,
-        dtype=torch.in64,
+        dtype=torch.int64,
     )
 
     decoder_out = model.decoder(decoder_input, need_pad=False)
diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py
index 51154cbb4..23ccdf6ad 100755
--- a/egs/librispeech/ASR/transducer_stateless/pretrained.py
+++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py
@@ -255,23 +255,26 @@ def main():
             encoder_out=encoder_out,
             beam=params.beam_size,
         )
-    for i in range(num_waves):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.method == "greedy_search":
-            hyp = greedy_search(
-                model=model,
-                encoder_out=encoder_out_i,
-                max_sym_per_frame=params.max_sym_per_frame,
-            )
-        elif params.method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(f"Unsupported method: {params.method}")
-        hyp_list.append(hyp)
+    else:
+        for i in range(num_waves):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(f"Unsupported method: {params.method}")
+            hyp_list.append(hyp)
 
     hyps = [sp.decode(hyp).split() for hyp in hyp_list]
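
For context, greedy_search_batch (the function whose dtype typo is fixed in the beam_search.py hunk above) decodes every utterance in the batch frame-synchronously instead of looping over utterances one at a time. Below is a simplified sketch of that idea, not the exact icefall implementation: the function name greedy_search_batch_sketch is made up for illustration, it ignores per-utterance padding of encoder_out, and it assumes a stateless transducer whose decoder consumes the last context_size token ids with need_pad=False and whose joiner maps a pair of (N, 1, C) tensors to per-utterance logits of shape (N, vocab_size).

    from typing import List

    import torch


    def greedy_search_batch_sketch(
        model: torch.nn.Module, encoder_out: torch.Tensor
    ) -> List[List[int]]:
        """encoder_out: (N, T, C). Returns one token-id list per utterance."""
        device = encoder_out.device
        batch_size = encoder_out.size(0)
        blank_id = model.decoder.blank_id
        context_size = model.decoder.context_size

        # Every hypothesis starts with `context_size` blanks so the
        # stateless decoder always sees a full context window.
        hyps = [[blank_id] * context_size for _ in range(batch_size)]
        decoder_input = torch.tensor(
            hyps,
            device=device,
            dtype=torch.int64,  # the dtype whose typo (torch.in64) is fixed above
        )
        decoder_out = model.decoder(decoder_input, need_pad=False)  # (N, 1, C)

        for t in range(encoder_out.size(1)):
            current_encoder_out = encoder_out[:, t : t + 1]  # (N, 1, C)
            # Assumed joiner interface: returns (N, vocab_size) logits.
            logits = model.joiner(current_encoder_out, decoder_out)
            y = logits.argmax(dim=-1).tolist()

            emitted = False
            for i, token in enumerate(y):
                if token != blank_id:
                    hyps[i].append(token)
                    emitted = True

            if emitted:
                # Re-run the decoder only when at least one hypothesis grew.
                decoder_input = torch.tensor(
                    [h[-context_size:] for h in hyps],
                    device=device,
                    dtype=torch.int64,
                )
                decoder_out = model.decoder(decoder_input, need_pad=False)

        # Drop the leading blanks that served as initial context.
        return [h[context_size:] for h in hyps]

This lock-step formulation is also why the pretrained.py changes above route greedy search through the batched path only when --max-sym-per-frame is 1: with at most one emitted symbol per frame, every utterance advances through its encoder frames at the same rate.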