mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix code style
This commit is contained in:
parent
f7a26400ab
commit
73ad3e3101
@ -17,17 +17,16 @@
|
||||
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformer import (
|
||||
Transformer,
|
||||
add_eos,
|
||||
add_sos,
|
||||
decoder_padding_mask,
|
||||
encoder_padding_mask,
|
||||
generate_square_subsequent_mask,
|
||||
decoder_padding_mask,
|
||||
add_sos,
|
||||
add_eos,
|
||||
)
|
||||
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def test_encoder_padding_mask():
|
||||
supervisions = {
|
||||
@ -82,11 +81,7 @@ def test_decoder_padding_mask():
|
||||
y = pad_sequence(x, batch_first=True, padding_value=-1)
|
||||
mask = decoder_padding_mask(y, ignore_id=-1)
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[False, False, True],
|
||||
[False, True, True],
|
||||
[False, False, False],
|
||||
]
|
||||
[[False, False, True], [False, True, True], [False, False, False]]
|
||||
)
|
||||
assert torch.all(torch.eq(mask, expected_mask))
|
||||
|
||||
|
@ -308,17 +308,13 @@ class AishellAsrDataModule(DataModule):
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
cuts_train = load_manifest(
|
||||
self.args.feature_dir / "cuts_train.json.gz"
|
||||
)
|
||||
cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = load_manifest(
|
||||
self.args.feature_dir / "cuts_dev.json.gz"
|
||||
)
|
||||
cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz")
|
||||
return cuts_valid
|
||||
|
||||
@lru_cache()
|
||||
|
@ -29,10 +29,7 @@ import torchaudio
|
||||
from model import TdnnLstm
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
)
|
||||
from icefall.decode import get_lattice, one_best_decoding
|
||||
from icefall.utils import AttributeDict, get_texts
|
||||
|
||||
|
||||
@ -203,7 +200,7 @@ def main():
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
assert(params.method == "1best")
|
||||
assert params.method == "1best"
|
||||
logging.info("Use HLG decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
|
Loading…
x
Reference in New Issue
Block a user