Make --block-size a list (one value per encoder stack)

This commit is contained in:
yaozengwei 2023-07-21 11:34:19 +08:00
parent 80a14f93d3
commit 6aaa971b34
2 changed files with 7 additions and 7 deletions

View File

@@ -189,9 +189,9 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--block-size", "--block-size",
type=int, type=str,
default="32", default="32",
help="Block size used in block-wise attention", help="Block size used in block-wise attention; a single int or comma-separated list",
) )
parser.add_argument( parser.add_argument(
@@ -581,7 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
num_heads=_to_int_tuple(params.num_heads), num_heads=_to_int_tuple(params.num_heads),
feedforward_dim=_to_int_tuple(params.feedforward_dim), feedforward_dim=_to_int_tuple(params.feedforward_dim),
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
block_size=params.block_size, block_size=_to_int_tuple(params.block_size),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0, warmup_batches=4000.0,
causal=params.causal, causal=params.causal,

View File

@@ -106,7 +106,7 @@ class Zipformer2(EncoderInterface):
feedforward_dim: Union[int, Tuple[int]] = 1536, feedforward_dim: Union[int, Tuple[int]] = 1536,
cnn_module_kernel: Union[int, Tuple[int]] = 31, cnn_module_kernel: Union[int, Tuple[int]] = 31,
pos_dim: int = 192, pos_dim: int = 192,
block_size: int = 32, block_size: Union[int, Tuple[int]] = 32,
dropout: FloatLike = None, # see code below for default dropout: FloatLike = None, # see code below for default
warmup_batches: float = 4000.0, warmup_batches: float = 4000.0,
causal: bool = False, causal: bool = False,
@@ -142,7 +142,7 @@ class Zipformer2(EncoderInterface):
self.num_heads = num_heads = _to_tuple(num_heads) self.num_heads = num_heads = _to_tuple(num_heads)
feedforward_dim = _to_tuple(feedforward_dim) feedforward_dim = _to_tuple(feedforward_dim)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
self.block_size = block_size self.block_size = block_size = _to_tuple(block_size)
self.causal = causal self.causal = causal
self.chunk_size = chunk_size self.chunk_size = chunk_size
@@ -168,7 +168,7 @@ class Zipformer2(EncoderInterface):
feedforward_dim=feedforward_dim[i], feedforward_dim=feedforward_dim[i],
dropout=dropout, dropout=dropout,
cnn_module_kernel=cnn_module_kernel[i], cnn_module_kernel=cnn_module_kernel[i],
block_size=block_size // ds, block_size=block_size[i],
causal=causal, causal=causal,
) )
@@ -178,7 +178,7 @@ class Zipformer2(EncoderInterface):
encoder_layer, encoder_layer,
num_encoder_layers[i], num_encoder_layers[i],
pos_dim=pos_dim, pos_dim=pos_dim,
block_size=block_size // ds, block_size=block_size[i],
dropout=dropout, dropout=dropout,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),