mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-10 17:44:20 +00:00
make block-size be a list
This commit is contained in:
parent
80a14f93d3
commit
6aaa971b34
@ -189,9 +189,9 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--block-size",
|
"--block-size",
|
||||||
type=int,
|
type=str,
|
||||||
default="32",
|
default="32",
|
||||||
help="Block size used in block-wise attention",
|
help="Block size used in block-wise attention; a single int or comma-separated list",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -581,7 +581,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
num_heads=_to_int_tuple(params.num_heads),
|
num_heads=_to_int_tuple(params.num_heads),
|
||||||
feedforward_dim=_to_int_tuple(params.feedforward_dim),
|
feedforward_dim=_to_int_tuple(params.feedforward_dim),
|
||||||
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
|
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
|
||||||
block_size=params.block_size,
|
block_size=_to_int_tuple(params.block_size),
|
||||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||||
warmup_batches=4000.0,
|
warmup_batches=4000.0,
|
||||||
causal=params.causal,
|
causal=params.causal,
|
||||||
|
@ -106,7 +106,7 @@ class Zipformer2(EncoderInterface):
|
|||||||
feedforward_dim: Union[int, Tuple[int]] = 1536,
|
feedforward_dim: Union[int, Tuple[int]] = 1536,
|
||||||
cnn_module_kernel: Union[int, Tuple[int]] = 31,
|
cnn_module_kernel: Union[int, Tuple[int]] = 31,
|
||||||
pos_dim: int = 192,
|
pos_dim: int = 192,
|
||||||
block_size: int = 32,
|
block_size: Union[int, Tuple[int]] = 32,
|
||||||
dropout: FloatLike = None, # see code below for default
|
dropout: FloatLike = None, # see code below for default
|
||||||
warmup_batches: float = 4000.0,
|
warmup_batches: float = 4000.0,
|
||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
@ -142,7 +142,7 @@ class Zipformer2(EncoderInterface):
|
|||||||
self.num_heads = num_heads = _to_tuple(num_heads)
|
self.num_heads = num_heads = _to_tuple(num_heads)
|
||||||
feedforward_dim = _to_tuple(feedforward_dim)
|
feedforward_dim = _to_tuple(feedforward_dim)
|
||||||
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
|
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
|
||||||
self.block_size = block_size
|
self.block_size = block_size = _to_tuple(block_size)
|
||||||
|
|
||||||
self.causal = causal
|
self.causal = causal
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
@ -168,7 +168,7 @@ class Zipformer2(EncoderInterface):
|
|||||||
feedforward_dim=feedforward_dim[i],
|
feedforward_dim=feedforward_dim[i],
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
cnn_module_kernel=cnn_module_kernel[i],
|
cnn_module_kernel=cnn_module_kernel[i],
|
||||||
block_size=block_size // ds,
|
block_size=block_size[i],
|
||||||
causal=causal,
|
causal=causal,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ class Zipformer2(EncoderInterface):
|
|||||||
encoder_layer,
|
encoder_layer,
|
||||||
num_encoder_layers[i],
|
num_encoder_layers[i],
|
||||||
pos_dim=pos_dim,
|
pos_dim=pos_dim,
|
||||||
block_size=block_size // ds,
|
block_size=block_size[i],
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
|
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
|
||||||
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user