diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 5a8dae619..eabed65fb 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -189,9 +189,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--block-size", - type=int, + type=str, default="32", - help="Block size used in block-wise attention", + help="Block size used in block-wise attention; a single int or comma-separated list", ) parser.add_argument( @@ -581,7 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: num_heads=_to_int_tuple(params.num_heads), feedforward_dim=_to_int_tuple(params.feedforward_dim), cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - block_size=params.block_size, + block_size=_to_int_tuple(params.block_size), dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), warmup_batches=4000.0, causal=params.causal, diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index d12b9f22b..0ca0fcaa4 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -106,7 +106,7 @@ class Zipformer2(EncoderInterface): feedforward_dim: Union[int, Tuple[int]] = 1536, cnn_module_kernel: Union[int, Tuple[int]] = 31, pos_dim: int = 192, - block_size: int = 32, + block_size: Union[int, Tuple[int]] = 32, dropout: FloatLike = None, # see code below for default warmup_batches: float = 4000.0, causal: bool = False, @@ -142,7 +142,7 @@ class Zipformer2(EncoderInterface): self.num_heads = num_heads = _to_tuple(num_heads) feedforward_dim = _to_tuple(feedforward_dim) self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) - self.block_size = block_size + self.block_size = block_size = _to_tuple(block_size) self.causal = causal self.chunk_size = chunk_size @@ -168,7 +168,7 @@ class Zipformer2(EncoderInterface): feedforward_dim=feedforward_dim[i], dropout=dropout, cnn_module_kernel=cnn_module_kernel[i], - block_size=block_size // ds, + block_size=block_size[i], causal=causal, ) @@ -178,7 +178,7 @@ class Zipformer2(EncoderInterface): encoder_layer, num_encoder_layers[i], pos_dim=pos_dim, - block_size=block_size // ds, + block_size=block_size[i], dropout=dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),