From 76a7bf13b2da917d8b667e07f40954d71fb8b3ef Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 12 Oct 2021 11:05:03 +0800 Subject: [PATCH] Fix style issues. --- egs/librispeech/ASR/conformer_ctc/export.py | 1 + egs/librispeech/ASR/conformer_ctc/transformer.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index b71385417..8241c84c1 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -139,6 +139,7 @@ def main(): model.load_state_dict(average_checkpoints(filenames)) model.to("cpu") + model.eval() if params.jit: logging.info("Using torch.jit.script") diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index f7467cce1..a2e36a41e 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -265,11 +265,15 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) @@ -333,11 +337,15 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)