Convert ScaledEmbedding to nn.Embedding for inference. (#517)

* Convert ScaledEmbedding to nn.Embedding for inference.

* Fix CI style issues.
Fangjun Kuang 2022-08-03 15:34:55 +08:00 committed by GitHub
parent 58a96e5b68
commit 6af5a82d8f
4 changed files with 59 additions and 24 deletions

@@ -495,9 +495,6 @@ class ScaledEmbedding(nn.Module):
         embedding_dim (int): the size of each embedding vector
         padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
                                      (initialized to zeros) whenever it encounters the index.
-        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
-                                    is renormalized to have norm :attr:`max_norm`.
-        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
         scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
                                                 the words in the mini-batch. Default ``False``.
         sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
@@ -506,7 +503,7 @@ class ScaledEmbedding(nn.Module):
         initial_speed (float, optional): This affects how fast the parameter will
             learn near the start of training; you can set it to a value less than
             one if you suspect that a module is contributing to instability near
-            the start of training. Nnote: regardless of the use of this option,
+            the start of training. Note: regardless of the use of this option,
             it's best to use schedulers like Noam that have a warm-up period.
             Alternatively you can set it to more than 1 if you want it to
             initially train faster. Must be greater than 0.
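For reference, a minimal usage sketch of the layer documented above. The construction mirrors the new test added later in this commit; treating `initial_speed` as a constructor keyword is an assumption based on the docstring, so it is left commented out.

    # Minimal sketch, not part of this commit: construct a ScaledEmbedding the
    # same way the new test below does and look up a batch of token IDs.
    import torch
    from scaling import ScaledEmbedding

    embed = ScaledEmbedding(
        num_embeddings=500,  # vocabulary size
        embedding_dim=10,    # size of each embedding vector
        padding_idx=0,       # row 0 is initialized to zeros, per the docstring
        # initial_speed=1.0, # assumed keyword; documented above for training stability
    )
    tokens = torch.randint(low=0, high=500, size=(8,))
    vectors = embed(tokens)  # behaves like nn.Embedding: shape (8, 10)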

@@ -16,11 +16,11 @@
 """
 This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
-and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`,
-and `nn.Conv2d`.
+`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts:
+`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.
 
 The scaled version are required only in the training time. It simplifies our
-life by converting them their non-scaled version during inference time.
+life by converting them to their non-scaled version during inference.
 """
 
 import copy
@@ -28,15 +28,7 @@ import re
 import torch
 import torch.nn as nn
 
-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
-
-
-def _get_weight(self: torch.nn.Linear):
-    return self.weight
-
-
-def _get_bias(self: torch.nn.Linear):
-    return self.bias
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
 
 
 def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
@@ -54,10 +46,6 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
     """
     assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)
 
-    # if not hasattr(torch.nn.Linear, "get_weight"):
-    #     torch.nn.Linear.get_weight = _get_weight
-    #     torch.nn.Linear.get_bias = _get_bias
-
     weight = scaled_linear.get_weight()
     bias = scaled_linear.get_bias()
     has_bias = bias is not None
@@ -148,6 +136,34 @@ def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
     return conv2d
 
 
+def scaled_embedding_to_embedding(
+    scaled_embedding: ScaledEmbedding,
+) -> nn.Embedding:
+    """Convert an instance of ScaledEmbedding to nn.Embedding.
+
+    Args:
+      scaled_embedding:
+        The layer to be converted.
+    Returns:
+      Return an instance of nn.Embedding that has the same `forward()` behavior
+      of the given `scaled_embedding`.
+    """
+    assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)
+    embedding = nn.Embedding(
+        num_embeddings=scaled_embedding.num_embeddings,
+        embedding_dim=scaled_embedding.embedding_dim,
+        padding_idx=scaled_embedding.padding_idx,
+        scale_grad_by_freq=scaled_embedding.scale_grad_by_freq,
+        sparse=scaled_embedding.sparse,
+    )
+    weight = scaled_embedding.weight
+    scale = scaled_embedding.scale
+
+    embedding.weight.data.copy_(weight * scale.exp())
+
+    return embedding
+
+
 def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
     """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
     in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
@@ -178,6 +194,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = scaled_conv1d_to_conv1d(m)
         elif isinstance(m, ScaledConv2d):
             d[name] = scaled_conv2d_to_conv2d(m)
+        elif isinstance(m, ScaledEmbedding):
+            d[name] = scaled_embedding_to_embedding(m)
 
     for k, v in d.items():
         if "." in k:

@@ -25,11 +25,12 @@ To run this file, do:
 import copy
 
 import torch
-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
 from scaling_converter import (
     convert_scaled_to_non_scaled,
     scaled_conv1d_to_conv1d,
     scaled_conv2d_to_conv2d,
+    scaled_embedding_to_embedding,
     scaled_linear_to_linear,
 )
 from train import get_params, get_transducer_model
@@ -135,6 +136,21 @@ def test_scaled_conv2d_to_conv2d():
     assert torch.allclose(y1, y4)
 
 
+def test_scaled_embedding_to_embedding():
+    scaled_embedding = ScaledEmbedding(
+        num_embeddings=500,
+        embedding_dim=10,
+        padding_idx=0,
+    )
+    embedding = scaled_embedding_to_embedding(scaled_embedding)
+
+    for s in [10, 100, 300, 500, 800, 1000]:
+        x = torch.randint(low=0, high=500, size=(s,))
+        scaled_y = scaled_embedding(x)
+        y = embedding(x)
+        assert torch.equal(scaled_y, y)
+
+
 def test_convert_scaled_to_non_scaled():
     for inplace in [False, True]:
         model = get_model()
@@ -193,6 +209,7 @@ def main():
     test_scaled_linear_to_linear()
     test_scaled_conv1d_to_conv1d()
     test_scaled_conv2d_to_conv2d()
+    test_scaled_embedding_to_embedding()
     test_convert_scaled_to_non_scaled()
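A hypothetical end-to-end sketch of how the converter might be used at inference time. `TinyDecoder` is a made-up toy module (not an icefall model), and it is assumed that `convert_scaled_to_non_scaled()` returns the converted module when `inplace=False`; only the class names imported below are taken from this commit.

    # Build a toy module that mixes scaled layers, convert it in one call,
    # and check that the forward outputs are preserved.
    import torch
    import torch.nn as nn
    from scaling import ScaledEmbedding, ScaledLinear
    from scaling_converter import convert_scaled_to_non_scaled

    class TinyDecoder(nn.Module):  # toy module for illustration only
        def __init__(self):
            super().__init__()
            self.embed = ScaledEmbedding(num_embeddings=500, embedding_dim=10)
            self.out = ScaledLinear(10, 500)

        def forward(self, x):
            return self.out(self.embed(x))

    model = TinyDecoder().eval()
    converted = convert_scaled_to_non_scaled(model, inplace=False)

    x = torch.randint(low=0, high=500, size=(4,))
    with torch.no_grad():
        assert torch.allclose(model(x), converted(x))
    print(type(converted.embed))  # expected to be nn.Embedding after conversion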

@@ -334,10 +334,13 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
+                word_fsa
+            )
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
+                word_fsa
+            )
 
         path_to_utt_map = self.shape.row_ids(1)