diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index f996e27f4..82da8f076 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -143,7 +143,7 @@ def get_params() -> AttributeDict: "num_encoder_layers": 12, "vgg_frontend": False, # parameters for decoder - "embedding_dim" : 512, + "embedding_dim": 512, "env_info": get_env_info(), } ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index 2df242f3a..c653cf3fc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -28,7 +28,8 @@ Usage: It will generate a file exp_dir/pretrained.pt -To use the generated file with `pruned_transducer_stateless/decode.py`, you can do: +To use the generated file with `pruned_transducer_stateless/decode.py`, +you can do: cd /path/to/exp_dir ln -s pretrained.pt epoch-9999.pt @@ -48,6 +49,7 @@ from pathlib import Path import sentencepiece as spm import torch +import torch.nn as nn from conformer import Conformer from decoder import Decoder from joiner import Joiner diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index c3c252cd0..4243ce418 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -136,11 +136,14 @@ class Transducer(nn.Module): lm_only_scale=self.lm_scale, am_only_scale=self.am_scale, boundary=boundary, - return_grad=True + return_grad=True, ) ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=self.prune_range + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=self.prune_range, ) am_pruned, lm_pruned = k2.do_rnnt_pruning( @@ -150,7 +153,11 @@ class Transducer(nn.Module): logits = self.joiner(am_pruned, lm_pruned) pruned_loss = k2.rnnt_loss_pruned( - joint=logits, symbols=y_padded, ranges=ranges, termination_symbol=blank_id, boundary=boundary + joint=logits, + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, ) return (-torch.sum(simple_loss), -torch.sum(pruned_loss)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index f689308bf..73c5aee5c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -49,6 +49,7 @@ from typing import List import kaldifeat import sentencepiece as spm import torch +import torch.nn as nn import torchaudio from beam_search import beam_search, greedy_search from conformer import Conformer @@ -142,7 +143,7 @@ def get_params() -> AttributeDict: "num_encoder_layers": 12, "vgg_frontend": False, # parameters for decoder - "embedding_dim" : 512, + "embedding_dim": 512, "env_info": get_env_info(), } ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 2694fd3a0..abd91d33f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -116,7 +116,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp", + default="pruned_transducer_stateless/exp", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved @@ -271,7 +271,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module: return decoder -def get_joiner_model(params: AttributeDict) -> nn.Module : +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.vocab_size, inner_dim=params.embedding_dim, @@ -280,7 +280,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module : return joiner -def get_transducer_model(params: AttributeDict) ->nn.Module: +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params)