diff --git a/egs/aishell/ASR/conformer_ctc/test_transformer.py b/egs/aishell/ASR/conformer_ctc/test_transformer.py
index b90215274..7c0695683 100644
--- a/egs/aishell/ASR/conformer_ctc/test_transformer.py
+++ b/egs/aishell/ASR/conformer_ctc/test_transformer.py
@@ -17,17 +17,16 @@
 
 
 import torch
+from torch.nn.utils.rnn import pad_sequence
 from transformer import (
     Transformer,
+    add_eos,
+    add_sos,
+    decoder_padding_mask,
     encoder_padding_mask,
     generate_square_subsequent_mask,
-    decoder_padding_mask,
-    add_sos,
-    add_eos,
 )
 
-from torch.nn.utils.rnn import pad_sequence
-
 
 def test_encoder_padding_mask():
     supervisions = {
@@ -82,11 +81,7 @@ def test_decoder_padding_mask():
     y = pad_sequence(x, batch_first=True, padding_value=-1)
     mask = decoder_padding_mask(y, ignore_id=-1)
     expected_mask = torch.tensor(
-        [
-            [False, False, True],
-            [False, True, True],
-            [False, False, False],
-        ]
+        [[False, False, True], [False, True, True], [False, False, False]]
     )
     assert torch.all(torch.eq(mask, expected_mask))
 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 9dede6288..9075ecb7e 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -308,17 +308,13 @@ class AishellAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(
-            self.args.feature_dir / "cuts_train.json.gz"
-        )
+        cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
         return cuts_train
 
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(
-            self.args.feature_dir / "cuts_dev.json.gz"
-        )
+        cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz")
         return cuts_valid
 
     @lru_cache()
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
index 8421dd3ea..b68221d08 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
@@ -29,10 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-)
+from icefall.decode import get_lattice, one_best_decoding
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -203,7 +200,7 @@ def main():
         subsampling_factor=params.subsampling_factor,
     )
 
-    assert(params.method == "1best")
+    assert params.method == "1best"
     logging.info("Use HLG decoding")
     best_path = one_best_decoding(
         lattice=lattice, use_double_scores=params.use_double_scores