Fix code style

This commit is contained in:
pkufool 2021-11-17 19:16:39 +08:00
parent f7a26400ab
commit 73ad3e3101
3 changed files with 9 additions and 21 deletions

View File

@ -17,17 +17,16 @@
import torch import torch
from torch.nn.utils.rnn import pad_sequence
from transformer import ( from transformer import (
Transformer, Transformer,
add_eos,
add_sos,
decoder_padding_mask,
encoder_padding_mask, encoder_padding_mask,
generate_square_subsequent_mask, generate_square_subsequent_mask,
decoder_padding_mask,
add_sos,
add_eos,
) )
from torch.nn.utils.rnn import pad_sequence
def test_encoder_padding_mask(): def test_encoder_padding_mask():
supervisions = { supervisions = {
@ -82,11 +81,7 @@ def test_decoder_padding_mask():
y = pad_sequence(x, batch_first=True, padding_value=-1) y = pad_sequence(x, batch_first=True, padding_value=-1)
mask = decoder_padding_mask(y, ignore_id=-1) mask = decoder_padding_mask(y, ignore_id=-1)
expected_mask = torch.tensor( expected_mask = torch.tensor(
[ [[False, False, True], [False, True, True], [False, False, False]]
[False, False, True],
[False, True, True],
[False, False, False],
]
) )
assert torch.all(torch.eq(mask, expected_mask)) assert torch.all(torch.eq(mask, expected_mask))

View File

@ -308,17 +308,13 @@ class AishellAsrDataModule(DataModule):
@lru_cache() @lru_cache()
def train_cuts(self) -> CutSet: def train_cuts(self) -> CutSet:
logging.info("About to get train cuts") logging.info("About to get train cuts")
cuts_train = load_manifest( cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
self.args.feature_dir / "cuts_train.json.gz"
)
return cuts_train return cuts_train
@lru_cache() @lru_cache()
def valid_cuts(self) -> CutSet: def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts") logging.info("About to get dev cuts")
cuts_valid = load_manifest( cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz")
self.args.feature_dir / "cuts_dev.json.gz"
)
return cuts_valid return cuts_valid
@lru_cache() @lru_cache()

View File

@ -29,10 +29,7 @@ import torchaudio
from model import TdnnLstm from model import TdnnLstm
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from icefall.decode import ( from icefall.decode import get_lattice, one_best_decoding
get_lattice,
one_best_decoding,
)
from icefall.utils import AttributeDict, get_texts from icefall.utils import AttributeDict, get_texts
@ -203,7 +200,7 @@ def main():
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
) )
assert(params.method == "1best") assert params.method == "1best"
logging.info("Use HLG decoding") logging.info("Use HLG decoding")
best_path = one_best_decoding( best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores lattice=lattice, use_double_scores=params.use_double_scores