diff --git a/egs/librispeech/ASR/zipformer/frame_reducer.py b/egs/librispeech/ASR/zipformer/frame_reducer.py new file mode 100644 index 000000000..099e76e22 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/frame_reducer.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from icefall.utils import make_pad_mask + +NON_BLANK_THRES = 0.9 + + +class FrameReducer(nn.Module): + """The encoder output is first used to calculate + the CTC posterior probability; then for each output frame, + if its blank posterior is bigger than some thresholds, + it will be simply discarded from the encoder output. + """ + + def __init__( + self, + ): + super().__init__() + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + ctc_output: torch.Tensor, + y_lens: Optional[torch.Tensor] = None, + blank_id: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The shared encoder output with shape [N, T, C]. + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + ctc_output: + The CTC output with shape [N, T, vocab_size]. + y_lens: + A tensor of shape (batch_size,) containing the number of frames in + `y` before padding. + blank_id: + The blank id of ctc_output. + Returns: + out: + The frame reduced encoder output with shape [N, T', C]. + out_lens: + A tensor of shape (batch_size,) containing the number of frames in + `out` before padding. 
+ """ + N, T, C = x.size() + + padding_mask = make_pad_mask(x_lens) + non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) * ( + ~padding_mask + ) + + if y_lens is not None or self.training is False: + # Limit the maximum number of reduced frames + if y_lens is not None: + limit_lens = T - y_lens + else: + # In eval mode, ensure audio that is completely silent does not make any errors + limit_lens = T - torch.ones_like(x_lens) + max_limit_len = limit_lens.max().int() + fake_limit_indexes = torch.topk( + ctc_output[:, :, blank_id], max_limit_len + ).indices + _T = ( + torch.arange(max_limit_len) + .expand_as( + fake_limit_indexes, + ) + .to(device=x.device) + ) + _T = torch.remainder(_T, limit_lens.unsqueeze(1)) + limit_indexes = torch.gather(fake_limit_indexes, 1, _T) + limit_mask = ( + torch.full_like( + non_blank_mask, + 0, + device=x.device, + ).scatter_(1, limit_indexes, 1) + == 1 + ) + + non_blank_mask = non_blank_mask | ~limit_mask + + out_lens = non_blank_mask.sum(dim=1) + max_len = out_lens.max() + pad_lens_list = ( + torch.full_like( + out_lens, + max_len.item(), + device=x.device, + ) + - out_lens + ) + max_pad_len = int(pad_lens_list.max().item()) + + out = F.pad(x, (0, 0, 0, max_pad_len)) + + valid_pad_mask = ~make_pad_mask(pad_lens_list) + total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1) + + out = out[total_valid_mask].reshape(N, -1, C) + + return out, out_lens + + +if __name__ == "__main__": + import time + + test_times = 10000 + device = "cuda:0" + frame_reducer = FrameReducer() + + # non zero case + x = torch.ones(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + ctc_output = torch.log( + torch.randn(15, 498, 500, dtype=torch.float32, device=device), + ) + + avg_time = 0 + for i in range(test_times): + torch.cuda.synchronize(device=x.device) + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) + + # all zero case + x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device) + x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device) + y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device) + ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device) + + avg_time = 0 + for i in range(test_times): + torch.cuda.synchronize(device=x.device) + delta_time = time.time() + x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens) + torch.cuda.synchronize(device=x.device) + delta_time = time.time() - delta_time + avg_time += delta_time + print(x_fr.shape) + print(x_lens_fr) + print(avg_time / test_times) diff --git a/egs/librispeech/ASR/zipformer/lconv.py b/egs/librispeech/ASR/zipformer/lconv.py new file mode 100644 index 000000000..a60cb743e --- /dev/null +++ b/egs/librispeech/ASR/zipformer/lconv.py @@ -0,0 +1,113 @@ +# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import torch.nn as nn +from scaling import Balancer, ScaledConv1d + + +class LConv(nn.Module): + """A convolution module to prevent information loss.""" + + def __init__( + self, + channels: int, + kernel_size: int = 7, + bias: bool = True, + ): + """ + Args: + channels: + Dimension of the input embedding, and of the lconv output. + """ + super().__init__() + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + self.deriv_balancer1 = Balancer( + 2 * channels, + channel_dim=1, + min_abs=0.05, + max_abs=10.0, + min_positive=0.05, + max_positive=1.0, + ) + + self.depthwise_conv = nn.Conv1d( + 2 * channels, + 2 * channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=2 * channels, + bias=bias, + ) + + self.deriv_balancer2 = Balancer( + 2 * channels, + channel_dim=1, + min_positive=0.05, + max_positive=1.0, + min_abs=0.05, + max_abs=20.0, + ) + + self.pointwise_conv2 = ScaledConv1d( + 2 * channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.05, + ) + + def forward( + self, + x: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + x: A 3-D tensor of shape (N, T, C). + Returns: + Return a tensor of shape (N, T, C). + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(0, 2, 1) # (#batch, channels, time). + + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + + if src_key_padding_mask is not None: + x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + + x = self.pointwise_conv2(x) # (batch, channels, time) + + return x.permute(0, 2, 1) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index f2f86af47..c1ab7b47f 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -16,16 +16,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +import warnings +from typing import List, Optional, Tuple import k2 import torch import torch.nn as nn from encoder_interface import EncoderInterface - -from icefall.utils import add_sos, make_pad_mask from scaling import ScaledLinear +from icefall.utils import add_sos, encode_supervisions, make_pad_mask + class AsrModel(nn.Module): def __init__( @@ -34,11 +35,14 @@ class AsrModel(nn.Module): encoder: EncoderInterface, decoder: Optional[nn.Module] = None, joiner: Optional[nn.Module] = None, + lconv: Optional[nn.Module] = None, + frame_reducer: Optional[nn.Module] = None, encoder_dim: int = 384, decoder_dim: int = 512, vocab_size: int = 500, use_transducer: bool = True, use_ctc: bool = False, + use_bs: bool = True, ): """A joint CTC & Transducer ASR model.
@@ -77,6 +81,10 @@ class AsrModel(nn.Module): use_transducer or use_ctc ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" + assert ( + (use_ctc and use_bs) or (use_ctc and not use_bs) or not (use_ctc and use_bs) + ), "Blank Skip needs CTC" + assert isinstance(encoder, EncoderInterface), type(encoder) self.encoder_embed = encoder_embed @@ -111,6 +119,11 @@ class AsrModel(nn.Module): nn.LogSoftmax(dim=-1), ) + self.use_bs = use_bs + if self.use_bs: + self.lconv = lconv + self.frame_reducer = frame_reducer + def forward_encoder( self, x: torch.Tensor, x_lens: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -142,217 +155,278 @@ class AsrModel(nn.Module): return encoder_out, encoder_out_lens - def forward_ctc( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - targets: torch.Tensor, - target_lengths: torch.Tensor, - ) -> torch.Tensor: - """Compute CTC loss. - Args: - encoder_out: - Encoder output, of shape (N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (N,). - targets: - Target Tensor of shape (sum(target_lengths)). The targets are assumed - to be un-padded and concatenated within 1 dimension. - """ - # Compute CTC log-prob - ctc_output = self.ctc_output(encoder_out) # (N, T, C) + def forward_ctc( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + targets: List[int], + target_lengths: torch.Tensor, + supervisions: dict, + subsampling_factor: int, + ctc_beam_size: int, + reduction: str = "sum", + warmup: float = 1.0, + ) -> torch.Tensor: + """Compute CTC loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + targets: + Target Tensor of shape (sum(target_lengths)). The targets are assumed + to be un-padded and concatenated within 1 dimension. + supervisions: + Dict into a pair of torch Tensor, and a list of transcription strings or token indexes + reduction: + Specifies the reduction to apply to the output + """ + # Compute CTC log-prob + ctc_output = self.ctc_output(encoder_out) # (N, T, C) + encoder_out_fr = encoder_out + encoder_out_lens_fr = encoder_out_lens - ctc_loss = torch.nn.functional.ctc_loss( - log_probs=ctc_output.permute(1, 0, 2), # (T, N, C) - targets=targets, - input_lengths=encoder_out_lens, - target_lengths=target_lengths, - reduction="sum", - ) - return ctc_loss + if self.use_bs and warmup >= 2.0: + # lconv + encoder_out = self.lconv( + x=encoder_out, + src_key_padding_mask=make_pad_mask(encoder_out_lens), + ) - def forward_transducer( - self, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - y: k2.RaggedTensor, - y_lens: torch.Tensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute Transducer loss. - Args: - encoder_out: - Encoder output, of shape (N, T, C). - encoder_out_lens: - Encoder output lengths, of shape (N,). - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. 
- am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - """ - # Now for the decoder, i.e., the prediction network - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) + # frame reduce + encoder_out_fr, encoder_out_lens_fr = self.frame_reducer( + encoder_out, + encoder_out_lens, + ctc_output, + target_lengths, + self.decoder.blank_id, + ) - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=subsampling_factor, + token_ids=targets, + ) - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros( - (encoder_out.size(0), 4), - dtype=torch.int64, - device=encoder_out.device, - ) - boundary[:, 2] = y_lens - boundary[:, 3] = encoder_out_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - # if self.training and random.random() < 0.25: - # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: - # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, + # TODO: Crash without this line + supervision_segments = supervision_segments.to("cpu") + decoding_graph = k2.ctc_graph( + token_ids, modified=False, device=encoder_out.device + ) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=subsampling_factor - 1, ) - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=ctc_beam_size, + reduction=reduction, + use_double_scores=True, + ) - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) + return ctc_loss, encoder_out_fr, encoder_out_lens_fr - # logits : [B, T, prune_range, vocab_size] + def forward_transducer( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + y: k2.RaggedTensor, + y_lens: torch.Tensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + delay_penalty: float = 0.0, + reduction: str = "sum", + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute Transducer loss. + Args: + encoder_out: + Encoder output, of shape (N, T, C). + encoder_out_lens: + Encoder output lengths, of shape (N,). + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. 
+ am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + reduction: + Specifies the reduction to apply to the output + """ + # Now for the decoder, i.e., the prediction network + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + delay_penalty=delay_penalty, + reduction=reduction, + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", ) - return simple_loss, pruned_loss + # logits : [B, T, prune_range, vocab_size] - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - Returns: - Return the transducer losses and CTC loss, - in form of (simple_loss, pruned_loss, ctc_loss) + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). 
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False) - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + delay_penalty=delay_penalty, + reduction=reduction, + ) - assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0) + return simple_loss, pruned_loss - # Compute encoder outputs - encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + supervisions: dict, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + subsampling_factor: int = 4, + ctc_beam_size: int = 10, + delay_penalty: float = 0.0, + reduction: str = "sum", + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + reduction: + Specifies the reduction to apply to the output + Returns: + Return the transducer losses and CTC loss, + in form of (simple_loss, pruned_loss, ctc_loss) + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] + assert x.size(0) == x_lens.size(0) == y.dim0 - if self.use_transducer: - # Compute transducer loss - simple_loss, pruned_loss = self.forward_transducer( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - y=y.to(x.device), - y_lens=y_lens, - prune_range=prune_range, - am_scale=am_scale, - lm_scale=lm_scale, - ) - else: - simple_loss = torch.empty(0) - pruned_loss = torch.empty(0) + # Compute encoder outputs + encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens) - if self.use_ctc: - # Compute CTC loss - targets = y.values - ctc_loss = self.forward_ctc( - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - targets=targets, - target_lengths=y_lens, - ) - else: - ctc_loss = torch.empty(0) + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] - return simple_loss, pruned_loss, ctc_loss + if self.use_ctc: + # Compute CTC loss + ctc_loss, encoder_out, encoder_out_lens = self.forward_ctc( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + targets=y.tolist(), + target_lengths=y_lens, + supervisions=supervisions, + subsampling_factor=subsampling_factor, + ctc_beam_size=ctc_beam_size, + reduction=reduction, + warmup=warmup, + ) + else: + ctc_loss = torch.empty(0, 
device=encoder_out.device) + + if self.use_transducer: + # Compute transducer loss + simple_loss, pruned_loss = self.forward_transducer( + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + y=y.to(x.device), + y_lens=y_lens, + prune_range=prune_range, + am_scale=am_scale, + lm_scale=lm_scale, + reduction=reduction, + delay_penalty=delay_penalty, + ) + else: + simple_loss = torch.empty(0) + pruned_loss = torch.empty(0) + + return simple_loss, pruned_loss, ctc_loss diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 3ccf7d2f1..e2b4ada70 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -67,7 +67,9 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from decoder import Decoder +from frame_reducer import FrameReducer from joiner import Joiner +from lconv import LConv from lhotse.cut import Cut from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed @@ -258,6 +260,13 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="If True, use CTC head.", ) + parser.add_argument( + "--use-bs", + type=str2bool, + default=False, + help="If True, use blank-skip.", + ) + def get_parser(): parser = argparse.ArgumentParser( @@ -529,6 +538,7 @@ def get_params() -> AttributeDict: "reset_interval": 200, "valid_interval": 3000, # For the 100h subset, use 800 # parameters for zipformer + "ctc_beam_size": 10, "feature_dim": 80, "subsampling_factor": 4, # not passed in, this is fixed. "warm_step": 2000, @@ -583,6 +593,16 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: return encoder +def get_lconv(params: AttributeDict) -> nn.Module: + lconv = LConv(channels=max(_to_int_tuple(params.encoder_dim))) + return lconv + + +def get_frame_reducer(params: AttributeDict) -> nn.Module: + frame_reducer = FrameReducer() + return frame_reducer + + def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, @@ -620,16 +640,24 @@ def get_model(params: AttributeDict) -> nn.Module: decoder = None joiner = None + lconv, frame_reducer = None, None + if params.use_bs: + lconv = get_lconv(params) + frame_reducer = get_frame_reducer(params) + model = AsrModel( encoder_embed=encoder_embed, encoder=encoder, decoder=decoder, joiner=joiner, + lconv=lconv, + frame_reducer=frame_reducer, encoder_dim=max(_to_int_tuple(params.encoder_dim)), decoder_dim=params.decoder_dim, vocab_size=params.vocab_size, use_transducer=params.use_transducer, use_ctc=params.use_ctc, + use_bs=params.use_bs, ) return model @@ -756,6 +784,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + warmup: float, ) -> Tuple[Tensor, MetricsTracker]: """ Compute loss given the model and its inputs. @@ -796,9 +825,12 @@ def compute_loss( x=feature, x_lens=feature_lens, y=y, + supervisions=supervisions, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + ctc_beam_size=params.ctc_beam_size, + warmup=warmup, ) loss = 0.0 @@ -859,6 +891,7 @@ def compute_validation_loss( sp=sp, batch=batch, is_training=False, + warmup=(params.batch_idx_train / params.warm_step), ) assert loss.requires_grad is False tot_loss = tot_loss + loss_info @@ -953,6 +986,7 @@ def train_one_epoch( sp=sp, batch=batch, is_training=True, + warmup=(params.batch_idx_train / params.warm_step), ) # summary stats tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
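
The core of the blank-skip technique added above can be illustrated in isolation. The sketch below is a minimal, single-utterance rendering of what FrameReducer does: frames whose CTC blank log-posterior is at or above log(NON_BLANK_THRES) are dropped from the encoder output. It is only an illustration under that reading of the patch; the padding handling, the cap on how many frames may be removed, and the batch re-padding in frame_reducer.py are omitted, and the helper name simple_frame_reduce and the toy tensors are made up for this sketch.

import math

import torch

NON_BLANK_THRES = 0.9  # same threshold as in frame_reducer.py


def simple_frame_reduce(
    x: torch.Tensor,  # (T, C) encoder output for one utterance
    ctc_log_probs: torch.Tensor,  # (T, V) CTC log-probabilities for the same frames
    blank_id: int = 0,
) -> torch.Tensor:
    # Keep a frame only if its blank posterior is below the threshold,
    # i.e. log P(blank) < log(NON_BLANK_THRES); all other frames are skipped.
    keep = ctc_log_probs[:, blank_id] < math.log(NON_BLANK_THRES)
    return x[keep]


# Toy inputs (hypothetical): 10 frames, 4-dim encoder output, vocabulary of 6.
toy_encoder_out = torch.randn(10, 4)
toy_ctc_log_probs = torch.randn(10, 6).log_softmax(dim=-1)
reduced = simple_frame_reduce(toy_encoder_out, toy_ctc_log_probs)
print(toy_encoder_out.shape, "->", reduced.shape)  # (10, 4) -> (T', 4) with T' <= 10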
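
A note on how the pieces interact: the reduced frames only affect the transducer branch, since forward_ctc computes the CTC loss on the full-length ctc_output and returns the lconv-processed, frame-reduced encoder output for the pruned RNN-T loss. The reduction is applied only when use_bs is set and warmup >= 2.0, i.e. once params.batch_idx_train reaches twice params.warm_step, so enabling blank skip requires the CTC head. A training invocation would presumably look like the following; only --use-ctc and --use-bs appear in this diff, and the remaining options are the recipe's usual ones, assumed here rather than shown:

  ./zipformer/train.py \
    --use-transducer 1 \
    --use-ctc 1 \
    --use-bs 1 \
    --exp-dir zipformer/exp-blank-skip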