Convert ScaledEmbedding to nn.Embedding for inference. (#517)

* Convert ScaledEmbedding to nn.Embedding for inference.

* Fix CI style issues.
Fangjun Kuang 2022-08-03 15:34:55 +08:00 committed by GitHub
parent 58a96e5b68
commit 6af5a82d8f
4 changed files with 59 additions and 24 deletions
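Usage note (not part of the commit): with this change, `convert_scaled_to_non_scaled` also handles embedding layers, so a training-time model can be turned into a plain-PyTorch model before inference. A minimal sketch, assuming the icefall `scaling` and `scaling_converter` modules are importable, that the hypothetical `ToyModel` below stands in for a real training-time model, and that the converter returns the converted copy when `inplace=False`:

    import torch
    import torch.nn as nn

    from scaling import ScaledEmbedding
    from scaling_converter import convert_scaled_to_non_scaled


    class ToyModel(nn.Module):
        """Hypothetical stand-in for a training-time model built with Scaled* layers."""

        def __init__(self) -> None:
            super().__init__()
            self.embed = ScaledEmbedding(num_embeddings=500, embedding_dim=10, padding_idx=0)
            self.proj = nn.Linear(10, 7)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(self.embed(x))


    model = ToyModel()
    model.eval()

    # Assumption: with inplace=False the converter leaves `model` untouched and
    # returns a converted copy that uses only standard torch.nn modules.
    inference_model = convert_scaled_to_non_scaled(model, inplace=False)
    assert isinstance(inference_model.embed, nn.Embedding)

    x = torch.randint(low=0, high=500, size=(8,))
    assert torch.allclose(model(x), inference_model(x))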


@@ -495,9 +495,6 @@ class ScaledEmbedding(nn.Module):
        embedding_dim (int): the size of each embedding vector
        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
                                     (initialized to zeros) whenever it encounters the index.
-        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
-                                    is renormalized to have norm :attr:`max_norm`.
-        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
                                                the words in the mini-batch. Default ``False``.
        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
@@ -506,7 +503,7 @@ class ScaledEmbedding(nn.Module):
        initial_speed (float, optional): This affects how fast the parameter will
            learn near the start of training; you can set it to a value less than
            one if you suspect that a module is contributing to instability near
-            the start of training. Nnote: regardless of the use of this option,
+            the start of training. Note: regardless of the use of this option,
            it's best to use schedulers like Noam that have a warm-up period.
            Alternatively you can set it to more than 1 if you want it to
            initially train faster. Must be greater than 0.
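For reference, a hedged construction sketch (not from the commit), assuming `ScaledEmbedding`'s constructor accepts the arguments documented above; only part of scaling.py is shown here:

    from scaling import ScaledEmbedding

    embed = ScaledEmbedding(
        num_embeddings=500,   # vocabulary size
        embedding_dim=10,
        padding_idx=0,        # the embedding for index 0 is initialized to zeros
        initial_speed=0.5,    # a value below 1 slows early learning, per the advice above
    )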


@@ -16,11 +16,11 @@
"""
This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
-and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`,
-and `nn.Conv2d`.
+`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts:
+`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.

The scaled version are required only in the training time. It simplifies our
-life by converting them their non-scaled version during inference time.
+life by converting them to their non-scaled version during inference.
"""
import copy
@@ -28,15 +28,7 @@ import re
import torch
import torch.nn as nn

-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
-def _get_weight(self: torch.nn.Linear):
-    return self.weight
-def _get_bias(self: torch.nn.Linear):
-    return self.bias
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
@@ -54,10 +46,6 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
    """
    assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)
-    # if not hasattr(torch.nn.Linear, "get_weight"):
-    # torch.nn.Linear.get_weight = _get_weight
-    # torch.nn.Linear.get_bias = _get_bias
    weight = scaled_linear.get_weight()
    bias = scaled_linear.get_bias()
    has_bias = bias is not None
@@ -148,6 +136,34 @@ def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
    return conv2d
+
+
+def scaled_embedding_to_embedding(
+    scaled_embedding: ScaledEmbedding,
+) -> nn.Embedding:
+    """Convert an instance of ScaledEmbedding to nn.Embedding.
+
+    Args:
+      scaled_embedding:
+        The layer to be converted.
+    Returns:
+      Return an instance of nn.Embedding that has the same `forward()` behavior
+      of the given `scaled_embedding`.
+    """
+    assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)
+    embedding = nn.Embedding(
+        num_embeddings=scaled_embedding.num_embeddings,
+        embedding_dim=scaled_embedding.embedding_dim,
+        padding_idx=scaled_embedding.padding_idx,
+        scale_grad_by_freq=scaled_embedding.scale_grad_by_freq,
+        sparse=scaled_embedding.sparse,
+    )
+    weight = scaled_embedding.weight
+    scale = scaled_embedding.scale
+
+    embedding.weight.data.copy_(weight * scale.exp())
+
+    return embedding


def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
"""Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
@@ -178,6 +194,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
            d[name] = scaled_conv1d_to_conv1d(m)
        elif isinstance(m, ScaledConv2d):
            d[name] = scaled_conv2d_to_conv2d(m)
+        elif isinstance(m, ScaledEmbedding):
+            d[name] = scaled_embedding_to_embedding(m)

    for k, v in d.items():
        if "." in k:

@@ -25,11 +25,12 @@ To run this file, do:
import copy

import torch
-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
from scaling_converter import (
    convert_scaled_to_non_scaled,
    scaled_conv1d_to_conv1d,
    scaled_conv2d_to_conv2d,
+    scaled_embedding_to_embedding,
    scaled_linear_to_linear,
)
from train import get_params, get_transducer_model
@@ -135,6 +136,21 @@ def test_scaled_conv2d_to_conv2d():
    assert torch.allclose(y1, y4)
+
+
+def test_scaled_embedding_to_embedding():
+    scaled_embedding = ScaledEmbedding(
+        num_embeddings=500,
+        embedding_dim=10,
+        padding_idx=0,
+    )
+    embedding = scaled_embedding_to_embedding(scaled_embedding)
+
+    for s in [10, 100, 300, 500, 800, 1000]:
+        x = torch.randint(low=0, high=500, size=(s,))
+        scaled_y = scaled_embedding(x)
+        y = embedding(x)
+        assert torch.equal(scaled_y, y)


def test_convert_scaled_to_non_scaled():
    for inplace in [False, True]:
        model = get_model()

@@ -193,6 +209,7 @@ def main():
    test_scaled_linear_to_linear()
    test_scaled_conv1d_to_conv1d()
    test_scaled_conv2d_to_conv2d()
+    test_scaled_embedding_to_embedding()
    test_convert_scaled_to_non_scaled()


@@ -334,10 +334,13 @@ class Nbest(object):
        if hasattr(lattice, "aux_labels"):
            # delete token IDs as it is not needed
            del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
+                word_fsa
+            )
        else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
+                word_fsa
+            )

        path_to_utt_map = self.shape.row_ids(1)