diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py
new file mode 100644
index 000000000..281785c4e
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/diagonalize.py
@@ -0,0 +1,204 @@
+#!/usr/bin/env python3
+# Copyright (c) 2022 Xiaomi Corporation (author: Daniel Povey)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import torch
+from torch import Tensor, nn
+
+# Utilities for diagonalizing models (rotating their parameter matrices
+# so that large and small parameter values are separated as much as
+# possible).
+
+
+def _get_normalized_covar(x: Tensor) -> Tensor:
+    """
+    Returns a covariance matrix normalized to have trace==dim, equal to
+    matmul(x, x.t()) times a constant.
+    Args:
+       x: a matrix of shape (i, j)
+    Returns: a covariance matrix of shape (i, i), equal to matmul(x, x.t())
+       times a constant chosen so that its trace equals i.
+    """
+    covar = torch.matmul(x, x.t())
+    return covar * (x.shape[0] / (covar.trace() + 1.0e-20))
+
+
+@torch.no_grad()
+def get_diag_covar_in(m: nn.Module) -> Tensor:
+    """
+    Returns a covariance matrix that shows, in the input space of
+    this module, which directions the parameter matrices vary in.
+    """
+    if isinstance(m, nn.Linear):
+        return _get_normalized_covar(m.weight.t())
+    elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
+        # m.weight is of size (out_channels, in_channels, kernel_size)
+        # or (out_channels, in_channels, kernel_dim0, kernel_dim1);
+        # we require groups == 1.
+        w = m.weight
+        assert m.groups == 1
+        out_channels = w.shape[0]
+        in_channels = w.shape[1]
+        w = w.reshape(out_channels, in_channels, -1)
+        w = w.permute(1, 0, 2)  # (in_channels, out_channels, kernel_size)
+        w = w.reshape(in_channels, -1)
+        return _get_normalized_covar(w)  # (in_channels, in_channels)
+    elif isinstance(m, nn.Sequential):
+        return get_diag_covar_in(m[0])
+    else:
+        # Some modules provide this method; if this one does not, it is an
+        # error to call get_diag_covar_in on it.
+        return m.get_diag_covar_in()
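+
+# Illustrative example in comment form (a hedged sketch, not executed as part
+# of this file): for m = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3),
+#
+#   c = get_diag_covar_in(m)   # shape (4, 4), symmetric
+#   c.trace()                  # ~= 4.0, by the trace normalization above
+#
+# i.e. the result lives in the module's input space, with trace equal to the
+# input dimension.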
+ """ + if isinstance(m, nn.Linear): + return _get_normalized_covar(m.weight); + elif isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d): + # m.weight is of size (out_channels, in_channels, kernel_size) + # or (out_channels, in_channels, kernel_dim0, kernel_dim1) + # assert here that groups == 1 + w = m.weight + assert m.groups == 1 + out_channels = w.shape[0] + in_channels = w.shape[1] + w = w.reshape(out_channels, -1) + return _get_normalized_covar(w) # (out_channels, out_channels) + + w = w.permute(1, 0, 2) # (in_channels, out_channels, kernel_size) + w = w.reshape(in_channels, -1) + return _get_normalized_covar(x) # (in_channels, in_channels) + elif isinstance(m, nn.Sequential): + return get_diag_covar_out(m[-1]) + else: + # some modules have this function; if not, at this point, it is an error. + return m.get_diag_covar_out() + +@torch.no_grad() +def get_diag_covar_inout(m: nn.Module) -> Tensor: + """ + Returns a covariance matrix that shows, in the input and + output space of this module, which are assumed to be the + same (e.g if it is a module intended to be added to a residual/ + bypass connection), + which direction parameter matrices vary in. + """ + if isinstance(m, nn.Sequential): + # this is only correct if it's a Sequential of non-residual modules. + return get_diag_covar_in(m[0]) + get_diag_covar_out(m[-1]) + else: + # some modules have this function; if not, at this point, it is an error. + return m.get_diag_covar_inout() + + +@torch.no_grad() +def apply_transformation_in(m: nn.Module, t: Tensor) -> None: + """ + Applies this transformation matrix on the input space of this module. + Args: + m: module to transform on the input space + t: transformation matrix, indexed (new_dim_in, old_dim_in) + """ + if isinstance(m, nn.Linear): + m.weight[:] = torch.matmul(m.weight, t.t()) + elif isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d): + # m.weight is of size (out_channels, in_channels, kernel_size) + # or (out_channels, in_channels, kernel_dim0, kernel_dim1) + # assert here that groups == 1 + w = m.weight + assert m.groups == 1 + out_channels = w.shape[0] + in_channels = w.shape[1] + w = w.reshape(out_channels, in_channels, -1) + w = w.permute(1, 0, 2) # (in_channels, out_channels, kernel_size) + w = w.reshape(in_channels, -1) + w = torch.matmul(t, w).reshape(in_channels, out_channels, -1) # (in_channels, out_channels, kernel_size) + w = w.permute(1, 0, 2) # (out_channels, in_channels, kernel_size) + w = w.reshape(m.weight.shape) # (out_channels, in_channels, [1 or 2 kernel dims]) + m.weight[:] = w + elif isinstance(m, nn.Sequential): + apply_transformation_in(m[0]) + else: + # some modules have this function; if not, at this point, it is an error. + m.apply_transformation_in(t) + +@torch.no_grad() +def apply_transformation_out(m: nn.Module, t: Tensor) -> None: + """ + Applies this transformation matrix on the output space of this module. 
+
+
+@torch.no_grad()
+def apply_transformation_out(m: nn.Module, t: Tensor) -> None:
+    """
+    Applies this transformation matrix on the output space of this module.
+    Args:
+       m: module to transform on the output space
+       t: transformation matrix, indexed (new_dim_out, old_dim_out)
+    """
+    if isinstance(m, nn.Linear):
+        m.weight[:] = torch.matmul(t, m.weight)
+        if m.bias is not None:
+            m.bias[:] = torch.matmul(t, m.bias)
+    elif isinstance(m, (nn.Conv1d, nn.Conv2d)):
+        # m.weight is of size (out_channels, in_channels, kernel_size)
+        # or (out_channels, in_channels, kernel_dim0, kernel_dim1);
+        # we require groups == 1.
+        w = m.weight
+        assert m.groups == 1
+        out_channels = w.shape[0]
+        w = w.reshape(out_channels, -1)
+        w = torch.matmul(t, w)
+        w = w.reshape(m.weight.shape)  # (out_channels, in_channels, [1 or 2 kernel dims])
+        m.weight[:] = w
+        if m.bias is not None:
+            m.bias[:] = torch.matmul(t, m.bias)
+    elif isinstance(m, nn.Sequential):
+        apply_transformation_out(m[-1], t)
+    else:
+        # Some modules provide this method; if this one does not, it is an
+        # error to call apply_transformation_out on it.
+        m.apply_transformation_out(t)
+
+
+@torch.no_grad()
+def apply_transformation_inout(m: nn.Module, t: Tensor) -> None:
+    """
+    Applies this transformation matrix on both the input and output space of
+    this module (which are assumed to be the same).
+    Args:
+       m: module to transform
+       t: transformation matrix, indexed (new_dim, old_dim)
+    """
+    if isinstance(m, nn.Sequential):
+        apply_transformation_in(m, t)
+        apply_transformation_out(m, t)
+    else:
+        # Some modules provide this method; if this one does not, it is an
+        # error to call apply_transformation_inout on it.
+        m.apply_transformation_inout(t)
+
+
+def get_transformation(cov: Tensor) -> Tensor:
+    """
+    Returns an orthogonal transformation that diagonalizes the covariance
+    matrix that is passed in.
+
+    Args: cov, of shape (dim0, dim0).
+
+    Returns: a transformation indexed (new_dim0, old_dim0), i.e. of
+    shape dim0 by dim0 but the 1st index is the newly created indexes.
+    """
+    old_diag_stddev = cov.diag().var().sqrt().item()
+    # torch.linalg.eigh returns the eigenvalues in ascending order and the
+    # corresponding eigenvectors as the columns of U.
+    eigs, U = torch.linalg.eigh(cov)
+    new_diag_stddev = eigs.var().sqrt().item()
+    logging.info(
+        f"Stddev of diag of param-var changed from {old_diag_stddev:.3e} "
+        f"to {new_diag_stddev:.3e}"
+    )
+    return U.t()  # U.t() is indexed (new_dim, old_dim)
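To make the intended workflow concrete, here is a minimal sketch (not part of
the diff itself; it assumes `diagonalize.py` is importable and uses only the
functions added above): measure a module's input-space covariance, find the
orthogonal transformation that diagonalizes it, and rotate the module's
parameters in place.

```python
import torch
from torch import nn

from diagonalize import (
    apply_transformation_in,
    get_diag_covar_in,
    get_transformation,
)

torch.manual_seed(0)
m = nn.Linear(16, 32)

cov = get_diag_covar_in(m)  # (16, 16), symmetric, trace ~= 16
t = get_transformation(cov)  # orthogonal, indexed (new_dim_in, old_dim_in)
apply_transformation_in(m, t)  # rotates m.weight in place, under no_grad

# After the rotation, the input-space covariance is (numerically) diagonal:
new_cov = get_diag_covar_in(m)
off_diag = new_cov - torch.diag(new_cov.diag())
assert off_diag.abs().max().item() < 1e-4
```

Because t is orthogonal, the rotation changes the basis of the parameters
without changing the function the module computes on correspondingly rotated
inputs.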
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py
deleted file mode 120000
index ebb6d774d..000000000
--- a/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py
+++ /dev/null
@@ -1 +0,0 @@
-../pruned_transducer_stateless2/model.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py b/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py
new file mode 100644
index 000000000..01b9ecc0e
--- /dev/null
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4b/model.py
@@ -0,0 +1,201 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from torch import Tensor
+from encoder_interface import EncoderInterface
+from scaling import ScaledLinear
+from diagonalize import get_diag_covar_in
+
+from icefall.utils import add_sos
+
+
+class Transducer(nn.Module):
+    """It implements https://arxiv.org/pdf/1211.3711.pdf
+    "Sequence Transduction with Recurrent Neural Networks"
+    """
+
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        decoder: nn.Module,
+        joiner: nn.Module,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of
+            shape (N,).  It returns two tensors: `logits` of shape
+            (N, T, encoder_dim) and `logit_lens` of shape (N,).
+          decoder:
+            It is the prediction network in the paper. Its input shape
+            is (N, U) and its output shape is (N, U, decoder_dim).
+            It should contain one attribute: `blank_id`.
+          joiner:
+            It has two inputs with shapes: (N, T, encoder_dim) and
+            (N, U, decoder_dim).  Its output shape is (N, T, U, vocab_size).
+            Note that its output contains unnormalized probs, i.e., not
+            processed by log-softmax.
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+        assert hasattr(decoder, "blank_id")
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joiner = joiner
+
+        self.simple_am_proj = ScaledLinear(
+            encoder_dim, vocab_size, initial_speed=0.5
+        )
+        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: k2.RaggedTensor,
+        prune_range: int = 5,
+        am_scale: float = 0.0,
+        lm_scale: float = 0.0,
+        warmup: float = 1.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in
+            `x` before padding.
+          y:
+            A ragged tensor with 2 axes [utt][label]. It contains labels of
+            each utterance.
+          prune_range:
+            The prune range for rnnt loss; it means how many symbols (context)
+            we are considering for each frame to compute the loss.
+          am_scale:
+            The scale to smooth the loss with am (output of the encoder
+            network) part.
+          lm_scale:
+            The scale to smooth the loss with lm (output of the prediction
+            network) part.
+          warmup:
+            A value warmup >= 0 that determines which modules are active; for
+            warmup > 1 the model is "fully warmed up" and all modules are
+            active.
+        Returns:
+          Return a tuple (simple_loss, pruned_loss) of transducer losses.
+
+        Note:
+           Regarding am_scale & lm_scale, it will make the loss-function one
+           of the form:
+              lm_scale * lm_probs + am_scale * am_probs +
+              (1-lm_scale-am_scale) * combined_probs
+        """
+        assert x.ndim == 3, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+        assert y.num_axes == 2, y.num_axes
+
+        assert x.size(0) == x_lens.size(0) == y.dim0
+
+        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        assert torch.all(x_lens > 0)
+
+        # Now for the decoder, i.e., the prediction network
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
+        blank_id = self.decoder.blank_id
+        sos_y = add_sos(y, sos_id=blank_id)
+
+        # sos_y_padded: [B, S + 1], starts with SOS.
+        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+        # decoder_out: [B, S + 1, decoder_dim]
+        decoder_out = self.decoder(sos_y_padded)
+
+        # Note: y does not start with SOS
+        # y_padded : [B, S]
+        y_padded = y.pad(mode="constant", padding_value=0)
+
+        y_padded = y_padded.to(torch.int64)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens
+
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=lm.float(),
+                am=am.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
+                am_only_scale=am_scale,
+                boundary=boundary,
+                reduction="sum",
+                return_grad=True,
+            )
+
+        # ranges : [B, T, prune_range]
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad=px_grad,
+            py_grad=py_grad,
+            boundary=boundary,
+            s_range=prune_range,
+        )
+
+        # am_pruned : [B, T, prune_range, encoder_dim]
+        # lm_pruned : [B, T, prune_range, decoder_dim]
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
+            am=self.joiner.encoder_proj(encoder_out),
+            lm=self.joiner.decoder_proj(decoder_out),
+            ranges=ranges,
+        )
+
+        # logits : [B, T, prune_range, vocab_size]
+
+        # project_input=False since we applied the encoder's and decoder's
+        # input projections prior to do_rnnt_pruning (this is an optimization
+        # for speed).
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
+                boundary=boundary,
+                reduction="sum",
+            )
+
+        return (simple_loss, pruned_loss)
+
+    def get_diag_covar_in(self) -> Tensor:
+        """
+        Returns a covariance matrix over the encoder-output space, pooled
+        from the modules that read from or write to that space: the simple
+        AM projection, the joiner's encoder projection, and the encoder
+        itself.
+        """
+        return (
+            get_diag_covar_in(self.simple_am_proj)
+            + get_diag_covar_in(self.joiner.encoder_proj)
+            + self.encoder.get_diag_covar_out()
+        )
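For context on how this new hook fits together with diagonalize.py, here is a
hedged sketch (not part of the diff): it assumes `model` is a Transducer as
above and that the recipe's encoder implements the
`apply_transformation_out()` hook that diagonalize.py dispatches to; the
helper name `diagonalize_encoder_out` is hypothetical.

```python
import torch

from diagonalize import apply_transformation_in, get_transformation


@torch.no_grad()
def diagonalize_encoder_out(model) -> None:
    # Covariance over the encoder-output space, pooled from the encoder
    # itself and the two projections that read from it.
    cov = model.get_diag_covar_in()  # (encoder_dim, encoder_dim)
    t = get_transformation(cov)  # orthogonal, indexed (new_dim, old_dim)

    # Rotate every module that writes to or reads from that space, so the
    # model computes the same function in the new (diagonalizing) basis.
    model.encoder.apply_transformation_out(t)
    apply_transformation_in(model.simple_am_proj, t)
    apply_transformation_in(model.joiner.encoder_proj, t)
```

Note that `ScaledLinear` in this recipe subclasses `nn.Linear`, so both
projections are handled by the `nn.Linear` branch of
`apply_transformation_in`.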