From d65187ec5245457a43e352f4c0c9930ab2d98225 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 11 Jul 2024 14:45:35 +0800 Subject: [PATCH] Small fix (#1686) --- egs/librispeech/ASR/zipformer/scaling.py | 5 +++-- egs/librispeech/ASR/zipformer/train.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index e7c3f4ab1..3c7e0fa4e 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -636,8 +636,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): ) def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: - """ - Forward function. Args: + """Forward function. + + Args: x: a Tensor of shape (batch_size, channels, seq_len) chunk_size: the chunk size, in frames; does not have to divide seq_len exactly. """ diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 3797de484..9b6f4a93a 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -406,7 +406,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( @@ -429,7 +429,7 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", + help="The scale to smooth the loss with am (output of encoder network) part.", ) parser.add_argument( @@ -848,7 +848,7 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. - warmup: a floating point value which increases throughout training; + warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ device = model.device if isinstance(model, DDP) else next(model.parameters()).device