diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index f4b81cfe3..c22e4ace6 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -116,7 +116,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, ) -from train import add_model_arguments, get_params, get_transducer_model +from train import add_model_arguments, get_params, get_model from icefall.checkpoint import ( average_checkpoints, @@ -694,7 +694,7 @@ def main(): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) if not params.use_averaged_model: if params.iter > 0: diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py index b996470aa..a100cbb8d 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ b/egs/librispeech/ASR/zipformer/export.py @@ -161,7 +161,7 @@ from typing import List, Tuple import sentencepiece as spm import torch from torch import Tensor, nn -from train import add_model_arguments, get_params, get_transducer_model +from train import add_model_arguments, get_params, get_model from icefall.checkpoint import ( average_checkpoints, @@ -408,7 +408,7 @@ def main(): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) if not params.use_averaged_model: if params.iter > 0: diff --git a/egs/librispeech/ASR/zipformer/generate_averaged_model.py b/egs/librispeech/ASR/zipformer/generate_averaged_model.py index fe29355f2..e0c7b52cb 100755 --- a/egs/librispeech/ASR/zipformer/generate_averaged_model.py +++ b/egs/librispeech/ASR/zipformer/generate_averaged_model.py @@ -44,7 +44,7 @@ import sentencepiece as spm import torch from asr_datamodule import LibriSpeechAsrDataModule -from train import add_model_arguments, get_params, get_transducer_model +from train import add_model_arguments, get_params, get_model from icefall.checkpoint import ( average_checkpoints_with_averaged_model, @@ -140,7 +140,7 @@ def main(): params.vocab_size = sp.get_piece_size() print("About to create model") - model = get_transducer_model(params) + model = get_model(params) if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 7fcab04ae..53f5adb8a 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
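For orientation, a minimal sketch of the pattern the updated decode.py, export.py and generate_averaged_model.py now share: build the network with the renamed `get_model()` and load checkpoint weights onto it. The `build_and_load` helper name and the `state["model"]` key are illustrative assumptions in the style of icefall checkpoints, not code from this patch; `params` is the `AttributeDict` produced by `get_params()` plus the parsed CLI arguments, as in the scripts above.

```python
import torch
from train import get_model  # assumes the recipe directory is on sys.path


def build_and_load(params, checkpoint_path: str, device: torch.device) -> torch.nn.Module:
    model = get_model(params)  # was get_transducer_model(params)
    # icefall-style checkpoints keep the weights under the "model" key
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state["model"], strict=True)
    model.to(device)
    model.eval()
    return model
```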
+from typing import Optional, Tuple

 import k2
 import torch
@@ -37,7 +38,7 @@ class Transducer(nn.Module):
         joiner: nn.Module,
         encoder_dim: int,
         decoder_dim: int,
-        joiner_dim: int,
+        joiner_dim: int,  # TODO: remove it
         vocab_size: int,
     ):
         """
@@ -69,16 +70,8 @@
         self.decoder = decoder
         self.joiner = joiner

-        self.simple_am_proj = ScaledLinear(
-            encoder_dim,
-            vocab_size,
-            initial_scale=0.25,
-        )
-        self.simple_lm_proj = ScaledLinear(
-            decoder_dim,
-            vocab_size,
-            initial_scale=0.25,
-        )
+        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_scale=0.25)
+        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size, initial_scale=0.25)

     def forward(
         self,
@@ -215,3 +208,310 @@
         )

         return (simple_loss, pruned_loss)
+
+
+class AsrModel(nn.Module):
+    def __init__(
+        self,
+        encoder_embed: nn.Module,
+        encoder: EncoderInterface,
+        decoder: Optional[nn.Module] = None,
+        joiner: Optional[nn.Module] = None,
+        encoder_dim: int = 384,
+        decoder_dim: int = 512,
+        vocab_size: int = 500,
+        use_transducer: bool = True,
+        use_ctc: bool = False,
+    ):
+        """A joint CTC & Transducer model.
+
+        Args:
+          encoder_embed:
+            It is a Convolutional 2D subsampling module. It converts
+            an input of shape (N, T, idim) to an output of shape
+            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
+          encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
+            `logit_lens` of shape (N,).
+          decoder:
+            It is the prediction network in the paper. Its input shape
+            is (N, U) and its output shape is (N, U, decoder_dim).
+            It should contain one attribute: `blank_id`.
+            It is used when use_transducer is True.
+          joiner:
+            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output contains
+            unnormalized probs, i.e., not processed by log-softmax.
+            It is used when use_transducer is True.
+          use_transducer:
+            Whether to use the transducer head. Default: True.
+          use_ctc:
+            Whether to use the CTC head. Default: False.
+        """
+        super().__init__()
+
+        assert (
+            use_transducer or use_ctc
+        ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
+
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+
+        self.encoder_embed = encoder_embed
+        self.encoder = encoder
+
+        self.use_transducer = use_transducer
+        if use_transducer:
+            # Modules for Transducer head
+            assert decoder is not None
+            assert hasattr(decoder, "blank_id")
+            assert joiner is not None
+
+            self.decoder = decoder
+            self.joiner = joiner
+
+            self.simple_am_proj = ScaledLinear(
+                encoder_dim, vocab_size, initial_scale=0.25
+            )
+            self.simple_lm_proj = ScaledLinear(
+                decoder_dim, vocab_size, initial_scale=0.25
+            )
+
+        self.use_ctc = use_ctc
+        if use_ctc:
+            # Modules for CTC head
+            self.ctc_output = nn.Sequential(
+                nn.Dropout(p=0.1),
+                nn.Linear(encoder_dim, vocab_size),
+                nn.LogSoftmax(dim=-1),
+            )
+
+    def forward_ctc(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        targets: torch.Tensor,
+        target_lengths: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute CTC loss.
+        Args:
+          encoder_out:
+            Encoder output, of shape (N, T, C).
+          encoder_out_lens:
+            Encoder output lengths, of shape (N,).
+          targets:
+            Target Tensor of shape (sum(target_lengths)). The targets are assumed
+            to be un-padded and concatenated within 1 dimension.
+ """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + + ctc_loss = torch.nn.functional.ctc_loss( + log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) + targets=targets, + input_lengths=encoder_out_lens, + target_lengths=target_lengths, + reduction="sum", + ) + return ctc_loss + + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). 
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return simple_loss, pruned_loss + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + # Compute encoder outputs + + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + assert torch.all(x_lens > 0) + + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=x_lens, + y=y, + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + ) + else: + simple_loss = 0 + pruned_loss = 0 + + if self.use_ctc: + # Compute CTC loss + targets = [t for tokens in y.tolist() for t in tokens] + # of shape (sum(y_lens),) + targets = torch.tensor(targets, device=x.device, dtype=torch.int64) + + ctc_loss = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=x_lens, + targets=targets, + target_lengths=y_lens, + ) + else: + ctc_loss = 0 + + return simple_loss, pruned_loss, ctc_loss diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index a4b7c2c36..b817a6b57 100755 --- a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -122,7 +122,7 @@ from beam_search import ( ) from icefall.utils import make_pad_mask from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model +from train import add_model_arguments, get_params, get_model def get_parser(): @@ -284,7 +284,7 @@ def main(): ), "left_context_frames should be one value in decoding." 
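The snippet below (illustrative only, assumes k2 is installed) mirrors how `forward()` above derives `y_lens` from the ragged label tensor and flattens the labels into the 1-D `targets` tensor that `forward_ctc()` expects.

```python
import k2
import torch

y = k2.RaggedTensor([[23, 7, 401], [9, 15], [5]])

row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]  # per-utterance label counts: [3, 2, 1]

# CTC wants one flat, un-padded tensor of shape (sum(y_lens),)
targets = torch.tensor(
    [t for tokens in y.tolist() for t in tokens], dtype=torch.int64
)
print(y_lens.tolist(), targets.tolist())   # [3, 2, 1] [23, 7, 401, 9, 15, 5]
```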
logging.info("Creating model") - model = get_transducer_model(params) + model = get_model(params) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index c2d58cb1e..4a377085a 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -51,7 +51,7 @@ from streaming_beam_search import ( ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model +from train import add_model_arguments, get_params, get_model from icefall.checkpoint import ( average_checkpoints, @@ -756,7 +756,7 @@ def main(): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) if not params.use_averaged_model: if params.iter > 0: diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 1f0741ba4..fe47601c2 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -71,6 +71,7 @@ from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer +from model import AsrModel from optim import Eden, ScaledAdam from torch import Tensor from torch.cuda.amp import GradScaler @@ -242,6 +243,20 @@ def add_model_arguments(parser: argparse.ArgumentParser): "chunk left-context frames will be chosen randomly from this list; else not relevant." ) + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -385,6 +400,13 @@ def get_parser(): "with this parameter before adding to the final loss.", ) + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + parser.add_argument( "--seed", type=int, @@ -585,21 +607,33 @@ def get_joiner_model(params: AttributeDict) -> nn.Module: return joiner -def get_transducer_model(params: AttributeDict) -> nn.Module: +def get_model(params: AttributeDict) -> nn.Module: + assert ( + params.use_transducer or params.use_ctc + ), (f"At least one of them should be True, " + f"but gotten params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}") + encoder_embed = get_encoder_embed(params) encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - model = Transducer( + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, decoder=decoder, joiner=joiner, - encoder_dim=int(max(params.encoder_dim.split(','))), + encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, ) return model @@ -728,7 +762,7 @@ def compute_loss( is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: """ - Compute CTC loss given the model and its inputs. + Compute loss given the model and its inputs. 
Args: params: @@ -766,7 +800,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( + simple_loss, pruned_loss, ctc_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -775,22 +809,27 @@ def compute_loss( lm_scale=params.lm_scale, ) - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) + loss = 0.0 - loss = ( - simple_loss_scale * simple_loss + - pruned_loss_scale * pruned_loss - ) + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += ( + simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss assert loss.requires_grad == is_training @@ -803,8 +842,11 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() return loss, info @@ -1081,10 +1123,13 @@ def run(rank, world_size, args): params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}")
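A standalone sketch of the loss combination that `compute_loss()` now performs, using the recipe's default scales (simple_loss_scale=0.5, ctc_loss_scale=0.2, warm_step=2000) as assumptions: the simple and pruned transducer terms are rescaled during the first `warm_step` batches, while the CTC term is added with a fixed `--ctc-loss-scale`.

```python
def combine_losses(
    simple_loss: float,
    pruned_loss: float,
    ctc_loss: float,
    batch_idx_train: int,
    warm_step: int = 2000,
    simple_loss_scale: float = 0.5,
    ctc_loss_scale: float = 0.2,
    use_transducer: bool = True,
    use_ctc: bool = True,
) -> float:
    loss = 0.0
    if use_transducer:
        s = simple_loss_scale
        # ramp the simple-loss weight down from 1.0 to s over warm_step batches
        simple_scale = (
            s
            if batch_idx_train >= warm_step
            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
        )
        # ramp the pruned-loss weight up from 0.1 to 1.0 over the same period
        pruned_scale = (
            1.0
            if batch_idx_train >= warm_step
            else 0.1 + 0.9 * (batch_idx_train / warm_step)
        )
        loss += simple_scale * simple_loss + pruned_scale * pruned_loss
    if use_ctc:
        loss += ctc_loss_scale * ctc_loss
    return loss


print(combine_losses(10.0, 20.0, 30.0, batch_idx_train=0))     # 10*1.0 + 20*0.1 + 30*0.2 = 18.0
print(combine_losses(10.0, 20.0, 30.0, batch_idx_train=5000))  # 10*0.5 + 20*1.0 + 30*0.2 = 31.0
```

Note that when `--use-transducer=False`, `run()` above resets `params.ctc_loss_scale` to 1.0, so a CTC-only model trains on the unscaled CTC loss.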