Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
Fix code style

commit 73ad3e3101
parent f7a26400ab
@@ -17,17 +17,16 @@
 
 
 import torch
+from torch.nn.utils.rnn import pad_sequence
 from transformer import (
     Transformer,
+    add_eos,
+    add_sos,
+    decoder_padding_mask,
     encoder_padding_mask,
     generate_square_subsequent_mask,
-    decoder_padding_mask,
-    add_sos,
-    add_eos,
 )
 
-from torch.nn.utils.rnn import pad_sequence
-
 
 def test_encoder_padding_mask():
     supervisions = {
@@ -82,11 +81,7 @@ def test_decoder_padding_mask():
     y = pad_sequence(x, batch_first=True, padding_value=-1)
     mask = decoder_padding_mask(y, ignore_id=-1)
     expected_mask = torch.tensor(
-        [
-            [False, False, True],
-            [False, True, True],
-            [False, False, False],
-        ]
+        [[False, False, True], [False, True, True], [False, False, False]]
     )
     assert torch.all(torch.eq(mask, expected_mask))
 
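
The hunk above only reformats the expected mask, but it also documents what
decoder_padding_mask returns: True at padded positions. A minimal sketch of the
same check that runs without icefall on the path; the concrete sequence values
are an assumption, only the lengths (2, 1, 3) are implied by expected_mask:

import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of lengths 2, 1 and 3 (values are illustrative).
x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([4, 5, 6])]
y = pad_sequence(x, batch_first=True, padding_value=-1)
# y: tensor([[ 1,  2, -1],
#            [ 3, -1, -1],
#            [ 4,  5,  6]])

# For this input, the positions equal to the padding value are exactly the
# positions decoder_padding_mask(y, ignore_id=-1) is expected to flag as True.
expected_mask = torch.tensor(
    [[False, False, True], [False, True, True], [False, False, False]]
)
assert torch.all(torch.eq(y == -1, expected_mask))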
@@ -308,17 +308,13 @@ class AishellAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(
-            self.args.feature_dir / "cuts_train.json.gz"
-        )
+        cuts_train = load_manifest(self.args.feature_dir / "cuts_train.json.gz")
         return cuts_train
 
     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(
-            self.args.feature_dir / "cuts_dev.json.gz"
-        )
+        cuts_valid = load_manifest(self.args.feature_dir / "cuts_dev.json.gz")
         return cuts_valid
 
     @lru_cache()
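
The train_cuts/valid_cuts methods above pair @lru_cache() with load_manifest so
the .json.gz manifest is parsed once and reused on later calls. A hedged sketch
of that caching pattern with a stand-in class; the class, path, and print below
are illustrative, not icefall code:

from functools import lru_cache
from pathlib import Path


class FakeDataModule:
    def __init__(self, feature_dir: str) -> None:
        self.feature_dir = Path(feature_dir)

    @lru_cache()
    def train_cuts(self) -> Path:
        # Stands in for load_manifest(self.feature_dir / "cuts_train.json.gz").
        print("loading manifest ...")  # printed only on the first call
        return self.feature_dir / "cuts_train.json.gz"


dm = FakeDataModule("data/fbank")
dm.train_cuts()  # reads (here: builds) the manifest
dm.train_cuts()  # served from the cache, no second load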
@@ -29,10 +29,7 @@ import torchaudio
 from model import TdnnLstm
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-)
+from icefall.decode import get_lattice, one_best_decoding
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -203,7 +200,7 @@ def main():
         subsampling_factor=params.subsampling_factor,
     )
 
-    assert(params.method == "1best")
+    assert params.method == "1best"
     logging.info("Use HLG decoding")
     best_path = one_best_decoding(
         lattice=lattice, use_double_scores=params.use_double_scores
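
The last hunk drops the parentheses around the assert. That is more than
cosmetic: assert is a statement, and the parenthesised form invites adding a
message as a second "argument", which silently turns the condition into an
always-truthy tuple. A small sketch of the pitfall (the method string is
illustrative, not an icefall value):

method = "nbest"

# Tuple form: the whole (condition, message) tuple is non-empty, hence truthy,
# so this "check" can never fail (CPython emits a SyntaxWarning for it).
assert (method == "1best", "unsupported method")  # noqa: F631

# Statement form: the condition is actually evaluated.
try:
    assert method == "1best", f"unsupported method: {method}"
except AssertionError as e:
    print(e)  # -> unsupported method: nbest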