Make block-size a list (one value per encoder)

This commit is contained in:
yaozengwei 2023-07-21 11:34:19 +08:00
parent 80a14f93d3
commit 6aaa971b34
2 changed files with 7 additions and 7 deletions

View File

@ -189,9 +189,9 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--block-size",
type=int,
type=str,
default="32",
help="Block size used in block-wise attention",
help="Block size used in block-wise attention; a single int or comma-separated list",
)
parser.add_argument(
@ -581,7 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
num_heads=_to_int_tuple(params.num_heads),
feedforward_dim=_to_int_tuple(params.feedforward_dim),
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
block_size=params.block_size,
block_size=_to_int_tuple(params.block_size),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
causal=params.causal,

View File

@ -106,7 +106,7 @@ class Zipformer2(EncoderInterface):
feedforward_dim: Union[int, Tuple[int]] = 1536,
cnn_module_kernel: Union[int, Tuple[int]] = 31,
pos_dim: int = 192,
block_size: int = 32,
block_size: Union[int, Tuple[int]] = 32,
dropout: FloatLike = None, # see code below for default
warmup_batches: float = 4000.0,
causal: bool = False,
@ -142,7 +142,7 @@ class Zipformer2(EncoderInterface):
self.num_heads = num_heads = _to_tuple(num_heads)
feedforward_dim = _to_tuple(feedforward_dim)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
self.block_size = block_size
self.block_size = block_size = _to_tuple(block_size)
self.causal = causal
self.chunk_size = chunk_size
@ -168,7 +168,7 @@ class Zipformer2(EncoderInterface):
feedforward_dim=feedforward_dim[i],
dropout=dropout,
cnn_module_kernel=cnn_module_kernel[i],
block_size=block_size // ds,
block_size=block_size[i],
causal=causal,
)
@ -178,7 +178,7 @@ class Zipformer2(EncoderInterface):
encoder_layer,
num_encoder_layers[i],
pos_dim=pos_dim,
block_size=block_size // ds,
block_size=block_size[i],
dropout=dropout,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),