diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
index a89d4e4a1..c474f51d5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py
@@ -116,7 +116,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from train import add_model_arguments, get_params, get_transducer_model, get_chunk_info
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 7197ace17..c11b78937 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -84,8 +84,6 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
-        chunk_size: int = -1,
-        left_context_chunks: int = -1,
     ) -> torch.Tensor:
         """
         Args:
@@ -106,9 +104,6 @@ class Transducer(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
-          chunk_size, left_context_chunks:
-            For chunkwise causal training; will be passed to the zipformer encoder.
-            chunk_size is specified in frames at 50Hz, i.e. after 2x downsampling.
         Returns:
           Return the transducer loss.

@@ -124,8 +119,8 @@ class Transducer(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens, chunk_size=chunk_size,
-                                           left_context_chunks=left_context_chunks)
+        encoder_out, x_lens = self.encoder(x, x_lens)
+        assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 3e1470fbf..2219cf350 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -1117,7 +1117,7 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
         # to make the convolution causal.
         left_pad = self.kernel_size // 2

-        if chunk_size < 0:
+        if chunk_size < 0 or chunk_size > seq_len:
             chunk_size = seq_len
         right_pad = -seq_len % chunk_size

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 2e5b2a641..52f25ae15 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -226,20 +226,29 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         """,
     )

     parser.add_argument(
-        "--chunk-size",
-        type=str,
-        default="-1",
-        help=" Embedding dimension in encoder stacks: a single int or comma-separated list."
+        "--causal",
+        type=str2bool,
+        default=True,
+        help="If True, use causal version of model.",
     )

     parser.add_argument(
-        "--chunk-left-context-frames",
+        "--chunk-size",
+        type=str,
+        default="16,32,64,-1",
+        help="Chunk sizes will be chosen randomly from this list during training. "
+        " Must be just -1 if --causal=False",
+    )
+
+    parser.add_argument(
+        "--left-context-frames",
         type=str,
         default="64,128,256,-1",
-        help="Left-contexts for chunkwise training, measured in frames (positive values must be "
-        "multiples of all positive elements of chunk-size). If --chunk-size is specified, "
-        "chunk left-context frames will be chosen randomly from this list."
+ help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant." ) @@ -544,7 +553,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: cnn_module_kernel=to_int_tuple(params.cnn_module_kernel), dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), warmup_batches=4000.0, - causal=(params.chunk_size != "-1"), + causal=params.causal, + chunk_size=to_int_tuple(params.chunk_size), + left_context_frames=to_int_tuple(params.left_context_frames), ) return encoder @@ -705,25 +716,6 @@ def save_checkpoint( copyfile(src=filename, dst=best_valid_filename) -def get_chunk_info(params: AttributeDict) -> Tuple[int, int]: - """ - Returns chunk_size and left_context_chunks. - """ - chunk_sizes = list(map(int, params.chunk_size.split(','))) - n = len(chunk_sizes) - chunk_size = random.choice(chunk_sizes) - if chunk_size == -1: - left_context_chunks = -1 - else: - chunk_left_context_frames = list(map(int, params.chunk_left_context_frames.split(','))) - m = len(chunk_left_context_frames) - left_context_frames = random.choice(chunk_left_context_frames) - if left_context_frames != -1: - assert left_context_frames % chunk_size == 0, "Invalid --chunk-left-context-frames value" - # Note: in Python, -1 // n == -1 for n > 0 - left_context_chunks = left_context_frames // chunk_size - return chunk_size, left_context_chunks - def compute_loss( params: AttributeDict, @@ -770,8 +762,6 @@ def compute_loss( y = sp.encode(texts, out_type=int) y = k2.RaggedTensor(y).to(device) - chunk_size, left_context_chunks = get_chunk_info(params) - with torch.set_grad_enabled(is_training): simple_loss, pruned_loss = model( x=feature, @@ -780,8 +770,6 @@ def compute_loss( prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, - chunk_size=chunk_size, - left_context_chunks=left_context_chunks, ) s = params.simple_loss_scale diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index 089b918e0..6da3566e9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -102,6 +102,12 @@ class Zipformer(EncoderInterface): slightly slower and use more memory. Enables use of the chunk_size and left_context_chunk options in forward(), which simulates streaming decoding. + chunk_size: (list of int): only set this to other than [-1] if causal; + the chunk size will be randomly chosen from this list. -1 means no chunking. + left_context_frames: (list of int): determines the number of left- + context chunks for causal training; will be rounded to a number of + chunks. Must not be less than cnn_module_kernel (after factoring in + rounding and downsampling); an error will be thrown if this is violated. 
""" def __init__( self, @@ -122,6 +128,8 @@ class Zipformer(EncoderInterface): dropout: FloatLike = None, # see code below for default warmup_batches: float = 4000.0, causal: bool = False, + chunk_size: Tuple[int] = [-1], + left_context_frames: Tuple[int] = [-1], ) -> None: super(Zipformer, self).__init__() @@ -162,6 +170,10 @@ class Zipformer(EncoderInterface): feedforward_dim = _to_tuple(feedforward_dim) self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + self.causal = causal + self.chunk_size = chunk_size + self.left_context_frames = left_context_frames + for u,d in zip(encoder_unmasked_dim, encoder_dim): assert u <= d @@ -319,10 +331,24 @@ class Zipformer(EncoderInterface): return feature_masks + def get_chunk_info(self) -> Tuple[int, int]: + """ + Returns chunk_size and left_context_chunks. + """ + if not self.causal: + return -1, -1 + chunk_size = random.choice(self.chunk_size) + if chunk_size == -1: + left_context_chunks = -1 + else: + left_context_frames = random.choice(self.left_context_frames) + # Note: in Python, -1 // n == -1 for n > 0 + left_context_chunks = left_context_frames // chunk_size + return chunk_size, left_context_chunks + + def forward( self, x: torch.Tensor, x_lens: torch.Tensor, - chunk_size: int = -1, - left_context_chunks: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -362,6 +388,8 @@ class Zipformer(EncoderInterface): outputs = [] feature_masks = self.get_feature_masks(x) + chunk_size, left_context_chunks = self.get_chunk_info() + attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks) for i, module in enumerate(self.encoders): @@ -2257,6 +2285,8 @@ def _test_zipformer_main(causal: bool = False): c = Zipformer( num_features=feature_dim, encoder_dim=(64,96), encoder_unmasked_dim=(48,64), num_heads=(4,4), causal=causal, + chunk_size=(4,) if causal else (-1,), + left_context_frames=(64,) ) batch_size = 5 seq_len = 20 @@ -2264,7 +2294,6 @@ def _test_zipformer_main(causal: bool = False): f = c( torch.randn(batch_size, seq_len, feature_dim), torch.full((batch_size,), seq_len, dtype=torch.int64), - chunk_size=4 if causal else -1, ) f[0].sum().backward() c.eval()