Change how chunk-size is specified

This commit is contained in:
Daniel Povey 2023-02-11 14:35:31 +08:00
parent e7e7560bba
commit 329175c897
5 changed files with 56 additions and 44 deletions

View File

@ -116,7 +116,7 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from train import add_model_arguments, get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model, get_chunk_info
from icefall.checkpoint import (
average_checkpoints,

View File

@ -84,8 +84,6 @@ class Transducer(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
chunk_size: int = -1,
left_context_chunks: int = -1,
) -> torch.Tensor:
"""
Args:
@ -106,9 +104,6 @@ class Transducer(nn.Module):
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
chunk_size, left_context_chunks:
For chunkwise causal training; will be passed to the zipformer encoder.
chunk_size is specified in frames at 50Hz, i.e. after 2x downsampling.
Returns:
Return the transducer loss.
@ -124,8 +119,8 @@ class Transducer(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens, chunk_size=chunk_size,
left_context_chunks=left_context_chunks)
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network

View File

@ -1117,7 +1117,7 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
# to make the convolution causal.
left_pad = self.kernel_size // 2
if chunk_size < 0:
if chunk_size < 0 or chunk_size > seq_len:
chunk_size = seq_len
right_pad = -seq_len % chunk_size

View File

@ -226,20 +226,29 @@ def add_model_arguments(parser: argparse.ArgumentParser):
""",
)
parser.add_argument(
"--chunk-size",
type=str,
default="-1",
help=" Embedding dimension in encoder stacks: a single int or comma-separated list."
"--causal",
type=str2bool,
default=True,
help="If True, use causal version of model.",
)
parser.add_argument(
"--chunk-left-context-frames",
"--chunk-size",
type=str,
default="16,32,64,-1",
help="Chunk sizes will be chosen randomly from this list during training. "
" Must be just -1 if --causal=False"
)
parser.add_argument(
"--left-context-frames",
type=str,
default="64,128,256,-1",
help="Left-contexts for chunkwise training, measured in frames (positive values must be "
"multiples of all positive elements of chunk-size). If --chunk-size is specified, "
"chunk left-context frames will be chosen randomly from this list."
help="Maximum left-contexts for causal training, measured in frames which will "
"be converted to a number of chunks. If splitting into chunks, "
"chunk left-context frames will be chosen randomly from this list; else not relevant."
)
@ -544,7 +553,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
cnn_module_kernel=to_int_tuple(params.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
causal=(params.chunk_size != "-1"),
causal=params.causal,
chunk_size=to_int_tuple(params.chunk_size),
left_context_frames=to_int_tuple(params.left_context_frames),
)
return encoder
@ -705,25 +716,6 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
def get_chunk_info(params: AttributeDict) -> Tuple[int, int]:
"""
Returns chunk_size and left_context_chunks.
"""
chunk_sizes = list(map(int, params.chunk_size.split(',')))
n = len(chunk_sizes)
chunk_size = random.choice(chunk_sizes)
if chunk_size == -1:
left_context_chunks = -1
else:
chunk_left_context_frames = list(map(int, params.chunk_left_context_frames.split(',')))
m = len(chunk_left_context_frames)
left_context_frames = random.choice(chunk_left_context_frames)
if left_context_frames != -1:
assert left_context_frames % chunk_size == 0, "Invalid --chunk-left-context-frames value"
# Note: in Python, -1 // n == -1 for n > 0
left_context_chunks = left_context_frames // chunk_size
return chunk_size, left_context_chunks
def compute_loss(
params: AttributeDict,
@ -770,8 +762,6 @@ def compute_loss(
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
chunk_size, left_context_chunks = get_chunk_info(params)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
x=feature,
@ -780,8 +770,6 @@ def compute_loss(
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
chunk_size=chunk_size,
left_context_chunks=left_context_chunks,
)
s = params.simple_loss_scale

View File

@ -102,6 +102,12 @@ class Zipformer(EncoderInterface):
slightly slower and use more memory. Enables use of the chunk_size and
left_context_chunk options in forward(), which simulates streaming
decoding.
chunk_size: (list of int): only set this to other than [-1] if causal;
the chunk size will be randomly chosen from this list. -1 means no chunking.
left_context_frames: (list of int): determines the number of left-
context chunks for causal training; will be rounded to a number of
chunks. Must not be less than cnn_module_kernel (after factoring in
rounding and downsampling); an error will be thrown if this is violated.
"""
def __init__(
self,
@ -122,6 +128,8 @@ class Zipformer(EncoderInterface):
dropout: FloatLike = None, # see code below for default
warmup_batches: float = 4000.0,
causal: bool = False,
chunk_size: Tuple[int] = [-1],
left_context_frames: Tuple[int] = [-1],
) -> None:
super(Zipformer, self).__init__()
@ -162,6 +170,10 @@ class Zipformer(EncoderInterface):
feedforward_dim = _to_tuple(feedforward_dim)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
self.causal = causal
self.chunk_size = chunk_size
self.left_context_frames = left_context_frames
for u,d in zip(encoder_unmasked_dim, encoder_dim):
assert u <= d
@ -319,10 +331,24 @@ class Zipformer(EncoderInterface):
return feature_masks
def get_chunk_info(self) -> Tuple[int, int]:
"""
Returns chunk_size and left_context_chunks.
"""
if not self.causal:
return -1, -1
chunk_size = random.choice(self.chunk_size)
if chunk_size == -1:
left_context_chunks = -1
else:
left_context_frames = random.choice(self.left_context_frames)
# Note: in Python, -1 // n == -1 for n > 0
left_context_chunks = left_context_frames // chunk_size
return chunk_size, left_context_chunks
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor,
chunk_size: int = -1,
left_context_chunks: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@ -362,6 +388,8 @@ class Zipformer(EncoderInterface):
outputs = []
feature_masks = self.get_feature_masks(x)
chunk_size, left_context_chunks = self.get_chunk_info()
attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
for i, module in enumerate(self.encoders):
@ -2257,6 +2285,8 @@ def _test_zipformer_main(causal: bool = False):
c = Zipformer(
num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4),
causal=causal,
chunk_size=(4,) if causal else (-1,),
left_context_frames=(64,)
)
batch_size = 5
seq_len = 20
@ -2264,7 +2294,6 @@ def _test_zipformer_main(causal: bool = False):
f = c(
torch.randn(batch_size, seq_len, feature_dim),
torch.full((batch_size,), seq_len, dtype=torch.int64),
chunk_size=4 if causal else -1,
)
f[0].sum().backward()
c.eval()