diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless3/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py
index 142067f1a..f59032db8 100644
--- a/egs/librispeech/ASR/transducer_lstm/model.py
+++ b/egs/librispeech/ASR/transducer_lstm/model.py
@@ -15,13 +15,42 @@
 # limitations under the License.
 
+from typing import Optional
+
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos
+from icefall.utils import add_sos, make_pad_mask
+
+
+def compute_teacher_student_loss(
+    encoder_out: torch.Tensor,
+    teacher_encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Args:
+      encoder_out:
+        Encoder output of the student. Its shape is (N, T, C).
+      teacher_encoder_out:
+        Encoder output of the teacher. Its shape is also (N, T, C).
+      encoder_out_lens:
+        A 1-D tensor containing the number of valid frames in encoder_out
+        before padding.
+    Returns:
+      Return the L1 loss between encoder_out and teacher_encoder_out,
+      summed over all valid frames and divided by the feature dimension C.
+    """
+    # Per-frame L1 distance, summed over the feature dimension: (N, T)
+    loss = (encoder_out - teacher_encoder_out).abs().sum(dim=-1)
+    # Exclude padding frames from the loss
+    mask = make_pad_mask(encoder_out_lens)
+    loss.masked_fill_(mask, 0)
+
+    return loss.sum() / encoder_out.size(-1)
 
 
 class Transducer(nn.Module):
@@ -51,9 +77,10 @@ class Transducer(nn.Module):
           is (N, U) and its output shape is (N, U, decoder_dim).
           It should contain one attribute: `blank_id`.
         joiner:
-          It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
-          Its output shape is (N, T, U, vocab_size). Note that its output contains
-          unnormalized probs, i.e., not processed by log-softmax.
+          It has two inputs with shapes: (N, T, encoder_dim) and
+          (N, U, decoder_dim).
+          Its output shape is (N, T, U, vocab_size). Note that its output
+          contains unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -76,6 +103,7 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        teacher_model: Optional[torch.jit.ScriptModule] = None,
     ) -> torch.Tensor:
         """
         Args:
@@ -96,6 +124,10 @@ class Transducer(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          teacher_model:
+            The teacher model. Its encoder output is used to compute the
+            teacher-student loss. It is required when the model is in
+            training mode.
         Returns:
-          Return the transducer loss.
+          Return a tuple of 3 scalar losses: (simple_loss, pruned_loss,
+          ts_loss). ts_loss is 0 when the model is in eval mode.
         """
@@ -111,8 +141,23 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens)
-        assert torch.all(x_lens > 0)
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens)
+        assert torch.all(encoder_out_lens > 0)
+
+        if self.training:
+            assert teacher_model is not None
+            # The teacher only provides targets; no gradients flow through it
+            with torch.no_grad():
+                teacher_encoder_out, _ = teacher_model.encoder(x, x_lens)
+
+            ts_loss = compute_teacher_student_loss(
+                encoder_out,
+                teacher_encoder_out,
+                encoder_out_lens,
+            )
+        else:
+            # Keep ts_loss on the same device as the other losses
+            ts_loss = torch.tensor([0.0], device=x.device)
 
         # Now for the decoder, i.e., the prediction network
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -136,7 +178,7 @@ class Transducer(nn.Module):
             (x.size(0), 4), dtype=torch.int64, device=x.device
         )
         boundary[:, 2] = y_lens
-        boundary[:, 3] = x_lens
+        boundary[:, 3] = encoder_out_lens
 
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)
@@ -186,4 +228,4 @@ class Transducer(nn.Module):
             reduction="sum",
         )
 
-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, ts_loss)
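The distillation loss added above can be checked in isolation. Below is a self-contained sketch (not part of the patch) that mirrors compute_teacher_student_loss with a local stand-in for icefall.utils.make_pad_mask; the stand-in, the toy shapes, and the __main__ check are illustrative assumptions.

import torch


def make_pad_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Mirrors the assumed semantics of icefall.utils.make_pad_mask:
    # True marks padding positions.
    seq = torch.arange(max_len, dtype=lengths.dtype)
    return seq.unsqueeze(0) >= lengths.unsqueeze(1)


def compute_teacher_student_loss(
    encoder_out: torch.Tensor,
    teacher_encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
) -> torch.Tensor:
    # Per-frame L1 distance, summed over the feature dimension: (N, T)
    loss = (encoder_out - teacher_encoder_out).abs().sum(dim=-1)
    # Zero out padding frames, then normalize by the feature dimension C
    loss.masked_fill_(make_pad_mask(encoder_out_lens, loss.size(1)), 0)
    return loss.sum() / encoder_out.size(-1)


if __name__ == "__main__":
    N, T, C = 2, 6, 4
    student = torch.rand(N, T, C)
    teacher = torch.rand(N, T, C)
    lens = torch.tensor([6, 3])
    # A student identical to the teacher yields exactly zero loss
    assert compute_teacher_student_loss(student, student, lens).item() == 0.0
    print(compute_teacher_student_loss(student, teacher, lens))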
diff --git a/egs/librispeech/ASR/transducer_lstm/teacher_model.py b/egs/librispeech/ASR/transducer_lstm/teacher_model.py
new file mode 100644
index 000000000..0ad31ba7e
--- /dev/null
+++ b/egs/librispeech/ASR/transducer_lstm/teacher_model.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def get_teacher_model() -> torch.jit.ScriptModule:
+    # NOTE: This is a hardcoded, machine-specific path to a TorchScript
+    # checkpoint (cpu_jit.pt) of the pruned_transducer_stateless3 model.
+    filename = "/ceph-fj/fangjun/open-source-2/icefall-master-2/egs/librispeech/ASR/pruned_transducer_stateless3/exp/cpu_jit.pt"
+    model = torch.jit.load(filename)
+    # The teacher is used for inference only
+    model.eval()
+
+    return model
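The checkpoint path in get_teacher_model() is specific to the author's machine. Here is a minimal sketch of a more portable loader; the --teacher-checkpoint flag and its default value are hypothetical and not part of this patch:

import argparse

import torch


def load_teacher(filename: str, device: str = "cpu") -> torch.jit.ScriptModule:
    # torch.jit.load accepts map_location, like torch.load
    model = torch.jit.load(filename, map_location=device)
    model.eval()  # disable dropout etc. in the teacher
    # The teacher is never updated, so gradients are not needed
    for p in model.parameters():
        p.requires_grad_(False)
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--teacher-checkpoint",
        default="pruned_transducer_stateless3/exp/cpu_jit.pt",
    )
    args = parser.parse_args()
    print(load_teacher(args.teacher_checkpoint))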
+ + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./transducer_lstm/test_teacher_model.py +""" + +import warnings + +import torch +from teacher_model import get_teacher_model + + +def test_teacher_model(): + model = get_teacher_model() + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of encoder model parameters: {num_param}") + + N = 3 + T = 500 + C = 80 + + x = torch.rand(N, T, C) + x_lens = torch.tensor([100, 500, 300]) + + y, y_lens = model.encoder(x, x_lens) + print(y.shape) + expected_y_lens = (((x_lens - 1) >> 1) - 1) >> 1 + + assert torch.all(torch.eq(y_lens, expected_y_lens)), ( + y_lens, + expected_y_lens, + ) + + +def main(): + test_teacher_model() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 17618b415..5692f91a1 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -65,6 +65,7 @@ from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer from optim import Eden, Eve +from teacher_model import get_teacher_model from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP @@ -229,6 +230,16 @@ def get_parser(): "with this parameter before adding to the final loss.", ) + parser.add_argument( + "--ts-loss-scale", + type=float, + default=0.1, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + parser.add_argument( "--seed", type=int, @@ -548,6 +559,7 @@ def compute_loss( sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, + teacher_model: Optional[torch.jit.ScriptModule] = None, ) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. @@ -564,6 +576,8 @@ def compute_loss( True for training. False for validation. When it is True, this function enables autograd during computation; when it is False, it disables autograd. + teacher_model: + The teacher model. 
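The expected_y_lens formula in the test encodes the encoder's two stride-2 subsampling steps, so the output is roughly a quarter of the input length: T -> ((T - 1) // 2 - 1) // 2, where >> 1 equals // 2 for non-negative lengths. A quick check of the arithmetic for the lengths used in the test:

def subsampled_length(t: int) -> int:
    # Two successive halvings, each computing (n - 1) // 2
    return ((t - 1) // 2 - 1) // 2


# x_lens in the test are [100, 500, 300]
for t, expected in [(100, 24), (500, 124), (300, 74)]:
    assert subsampled_length(t) == expected, (t, subsampled_length(t))
print("subsampling arithmetic OK")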
""" device = params.device feature = batch["inputs"] @@ -579,16 +593,18 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( + simple_loss, pruned_loss, ts_loss = model( x=feature, x_lens=feature_lens, y=y, prune_range=params.prune_range, am_scale=params.am_scale, lm_scale=params.lm_scale, + teacher_model=teacher_model, ) loss = params.simple_loss_scale * simple_loss + pruned_loss + loss = loss + params.ts_loss_scale * ts_loss assert loss.requires_grad == is_training @@ -603,6 +619,7 @@ def compute_loss( info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["ts_loss"] = ts_loss.detach().cpu().item() return loss, info @@ -623,6 +640,7 @@ def compute_validation_loss( loss, loss_info = compute_loss( params=params, model=model, + teacher_model=None, sp=sp, batch=batch, is_training=False, @@ -644,6 +662,7 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, model: nn.Module, + teacher_model: Optional[torch.jit.ScriptModule], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, @@ -666,6 +685,8 @@ def train_one_epoch( It is returned by :func:`get_params`. model: The model for training. + teacher_model: + The teacher model. optimizer: The optimizer we are using. scheduler: @@ -729,6 +750,7 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, + teacher_model=teacher_model, sp=sp, batch=batch, is_training=True, @@ -901,10 +923,14 @@ def run(rank, world_size, args): logging.info("About to create model") model = get_transducer_model(params) + teacher_model = get_teacher_model() num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + num_teacher_param = sum([p.numel() for p in teacher_model.parameters()]) + logging.info(f"Number of teacher model parameters: {num_teacher_param}") + assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None if rank == 0: @@ -919,6 +945,7 @@ def run(rank, world_size, args): ) model.to(device) + teacher_model.to(device) if world_size > 1: logging.info("Using DDP") model = DDP(model, device_ids=[rank]) @@ -983,6 +1010,7 @@ def run(rank, world_size, args): if not params.print_diagnostics: scan_pessimistic_batches_for_oom( model=model, + teacher_model=teacher_model, train_dl=train_dl, optimizer=optimizer, sp=sp, @@ -1007,6 +1035,7 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, + teacher_model=teacher_model, model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, @@ -1075,6 +1104,7 @@ def display_and_save_batch( def scan_pessimistic_batches_for_oom( model: nn.Module, + teacher_model: torch.jit.ScriptModule, train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, @@ -1093,6 +1123,7 @@ def scan_pessimistic_batches_for_oom( loss, _ = compute_loss( params=params, model=model, + teacher_model=teacher_model, sp=sp, batch=batch, is_training=True, @@ -1130,5 +1161,9 @@ def main(): torch.set_num_threads(1) torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) + if __name__ == "__main__": main()