From 543b4cc1ca45f5a6e273cb1440a233e5fc51fa36 Mon Sep 17 00:00:00 2001 From: Karel Vesely Date: Thu, 19 Oct 2023 15:53:31 +0200 Subject: [PATCH] small enhancements (#1322) - add extra check of 'x' and 'x_lens' at an earlier point in Transducer model - specify 'utf8' encoding when opening text files for writing (recogs, errs) --- egs/librispeech/ASR/pruned_transducer_stateless7/model.py | 3 +++ icefall/utils.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 0e59b0f2f..add0e6a18 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -114,6 +114,9 @@ class Transducer(nn.Module): assert x.size(0) == x_lens.size(0) == y.dim0 + # x.T_dim == max(x_len) + assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max()) + encoder_out, x_lens = self.encoder(x, x_lens) assert torch.all(x_lens > 0) diff --git a/icefall/utils.py b/icefall/utils.py index 6479d8f87..399e8d8b3 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -498,7 +498,7 @@ def store_transcripts( Returns: Return None. """ - with open(filename, "w") as f: + with open(filename, "w", encoding="utf8") as f: for cut_id, ref, hyp in texts: if char_level: ref = list("".join(ref)) @@ -523,7 +523,7 @@ def store_transcripts_and_timestamps( Returns: Return None. """ - with open(filename, "w") as f: + with open(filename, "w", encoding="utf8") as f: for cut_id, ref, hyp, time_ref, time_hyp in texts: print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f)