diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
index 2b44dc649..566b3622f 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py
@@ -495,9 +495,6 @@ class ScaledEmbedding(nn.Module):
         embedding_dim (int): the size of each embedding vector
         padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
                                      (initialized to zeros) whenever it encounters the index.
-        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
-                                    is renormalized to have norm :attr:`max_norm`.
-        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
         scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
                                                 the words in the mini-batch. Default ``False``.
         sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
@@ -506,7 +503,7 @@ class ScaledEmbedding(nn.Module):
         initial_speed (float, optional):  This affects how fast the parameter will
            learn near the start of training; you can set it to a value less than
            one if you suspect that a module is contributing to instability near
-           the start of training.  Nnote: regardless of the use of this option,
+           the start of training.  Note: regardless of the use of this option,
            it's best to use schedulers like Noam that have a warm-up period.
            Alternatively you can set it to more than 1 if you want it to
            initially train faster.   Must be greater than 0.
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
index c810e36e6..79b178421 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py
@@ -16,11 +16,11 @@
 
 """
 This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
-and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`,
-and `nn.Conv2d`.
+`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts:
+`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.
 
 The scaled version are required only in the training time. It simplifies our
-life by converting them their non-scaled version during inference time.
+life by converting them to their non-scaled version during inference.
 """
 
 import copy
@@ -28,15 +28,7 @@ import re
 
 import torch
 import torch.nn as nn
-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
-
-
-def _get_weight(self: torch.nn.Linear):
-    return self.weight
-
-
-def _get_bias(self: torch.nn.Linear):
-    return self.bias
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
 
 
 def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
@@ -54,10 +46,6 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
     """
     assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)
 
-    # if not hasattr(torch.nn.Linear, "get_weight"):
-    #     torch.nn.Linear.get_weight = _get_weight
-    #     torch.nn.Linear.get_bias = _get_bias
-
     weight = scaled_linear.get_weight()
     bias = scaled_linear.get_bias()
     has_bias = bias is not None
@@ -148,6 +136,34 @@ def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
     return conv2d
 
 
+def scaled_embedding_to_embedding(
+    scaled_embedding: ScaledEmbedding,
+) -> nn.Embedding:
+    """Convert an instance of ScaledEmbedding to nn.Embedding.
+
+    Args:
+      scaled_embedding:
+        The layer to be converted.
+    Returns:
+      Return an instance of nn.Embedding that has the same `forward()` behavior
+      of the given `scaled_embedding`.
+    """
+    assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)
+    embedding = nn.Embedding(
+        num_embeddings=scaled_embedding.num_embeddings,
+        embedding_dim=scaled_embedding.embedding_dim,
+        padding_idx=scaled_embedding.padding_idx,
+        scale_grad_by_freq=scaled_embedding.scale_grad_by_freq,
+        sparse=scaled_embedding.sparse,
+    )
+    weight = scaled_embedding.weight
+    scale = scaled_embedding.scale
+
+    embedding.weight.data.copy_(weight * scale.exp())
+
+    return embedding
+
+
 def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
     """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
     in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
@@ -178,6 +194,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = scaled_conv1d_to_conv1d(m)
         elif isinstance(m, ScaledConv2d):
             d[name] = scaled_conv2d_to_conv2d(m)
+        elif isinstance(m, ScaledEmbedding):
+            d[name] = scaled_embedding_to_embedding(m)
 
     for k, v in d.items():
         if "." in k:
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py
index 34a9c27f7..a9feea83c 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling_converter.py
@@ -25,11 +25,12 @@ To run this file, do:
 import copy
 
 import torch
-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
 from scaling_converter import (
     convert_scaled_to_non_scaled,
     scaled_conv1d_to_conv1d,
     scaled_conv2d_to_conv2d,
+    scaled_embedding_to_embedding,
     scaled_linear_to_linear,
 )
 from train import get_params, get_transducer_model
@@ -135,6 +136,21 @@ def test_scaled_conv2d_to_conv2d():
     assert torch.allclose(y1, y4)
 
 
+def test_scaled_embedding_to_embedding():
+    scaled_embedding = ScaledEmbedding(
+        num_embeddings=500,
+        embedding_dim=10,
+        padding_idx=0,
+    )
+    embedding = scaled_embedding_to_embedding(scaled_embedding)
+
+    for s in [10, 100, 300, 500, 800, 1000]:
+        x = torch.randint(low=0, high=500, size=(s,))
+        scaled_y = scaled_embedding(x)
+        y = embedding(x)
+        assert torch.equal(scaled_y, y)
+
+
 def test_convert_scaled_to_non_scaled():
     for inplace in [False, True]:
         model = get_model()
@@ -193,6 +209,7 @@ def main():
     test_scaled_linear_to_linear()
     test_scaled_conv1d_to_conv1d()
     test_scaled_conv2d_to_conv2d()
+    test_scaled_embedding_to_embedding()
     test_convert_scaled_to_non_scaled()
 
 
diff --git a/icefall/decode.py b/icefall/decode.py
index 3b64481c7..f04ee368c 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -334,10 +334,13 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
+                word_fsa
+            )
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
-
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
+                word_fsa
+            )
 
         path_to_utt_map = self.shape.row_ids(1)
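
For reference, the new `scaled_embedding_to_embedding` conversion folds the learned log-scale into the embedding weights (`weight * scale.exp()`), so the resulting `nn.Embedding` reproduces the scaled module's `forward()` output exactly; this is why the new test asserts `torch.equal` rather than `torch.allclose`. Below is a minimal, self-contained sketch of that idea in plain PyTorch. `ToyScaledEmbedding` and `toy_convert` are hypothetical stand-ins written only for illustration, not icefall's actual `ScaledEmbedding` or converter.

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyScaledEmbedding(nn.Module):
    # Stand-in for icefall's ScaledEmbedding (assumed behavior for illustration):
    # forward() returns the looked-up rows multiplied by scale.exp().
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        self.scale = nn.Parameter(torch.zeros(()))  # log-scale parameter

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.embedding(x, self.weight) * self.scale.exp()


def toy_convert(m: ToyScaledEmbedding) -> nn.Embedding:
    # Fold the scale into the weights once, so inference needs only nn.Embedding.
    embedding = nn.Embedding(m.weight.shape[0], m.weight.shape[1])
    embedding.weight.data.copy_(m.weight * m.scale.exp())
    return embedding


scaled = ToyScaledEmbedding(num_embeddings=500, embedding_dim=10)
converted = toy_convert(scaled)
x = torch.randint(low=0, high=500, size=(8,))
# Bitwise identical: the same product weight * scale.exp() is either applied
# on the fly or pre-baked into the nn.Embedding weights.
assert torch.equal(scaled(x), converted(x))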