Mirror of https://github.com/k2-fsa/icefall.git
Convert ScaledEmbedding to nn.Embedding for inference. (#517)
* Convert ScaledEmbedding to nn.Embedding for inference.
* Fix CI style issues.
This commit is contained in:
parent 58a96e5b68
commit 6af5a82d8f
@@ -495,9 +495,6 @@ class ScaledEmbedding(nn.Module):
         embedding_dim (int): the size of each embedding vector
         padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
                                      (initialized to zeros) whenever it encounters the index.
-        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
-                                    is renormalized to have norm :attr:`max_norm`.
-        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
         scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
                                                 the words in the mini-batch. Default ``False``.
         sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
@@ -506,7 +503,7 @@ class ScaledEmbedding(nn.Module):
         initial_speed (float, optional): This affects how fast the parameter will
            learn near the start of training; you can set it to a value less than
            one if you suspect that a module is contributing to instability near
-           the start of training. Nnote: regardless of the use of this option,
+           the start of training. Note: regardless of the use of this option,
            it's best to use schedulers like Noam that have a warm-up period.
            Alternatively you can set it to more than 1 if you want it to
            initially train faster. Must be greater than 0.
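For context on the docstring hunks above, here is a minimal usage sketch of constructing a ScaledEmbedding with these arguments. It assumes the class is importable from the recipe-local `scaling` module (as the test file further down imports it) and that `initial_speed` is accepted as a constructor keyword, per the docstring; the specific sizes are illustrative only.

import torch

from scaling import ScaledEmbedding  # recipe-local module, as in the imports below

# Hypothetical example: a 500-entry table of 10-dimensional vectors.
# padding_idx=0 keeps index 0 mapped to an all-zeros vector (see docstring above);
# initial_speed < 1 slows early learning of this parameter and must be > 0.
emb = ScaledEmbedding(
    num_embeddings=500,
    embedding_dim=10,
    padding_idx=0,
    initial_speed=0.5,  # assumed constructor keyword, per the docstring
)

tokens = torch.randint(low=0, high=500, size=(8,))
vectors = emb(tokens)  # expected shape: (8, 10)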
@@ -16,11 +16,11 @@

 """
 This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
-and `ScaledConv2d` to their non-scaled counterparts: `nn.Linear`, `nn.Conv1d`,
-and `nn.Conv2d`.
+`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts:
+`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.

 The scaled version are required only in the training time. It simplifies our
-life by converting them their non-scaled version during inference time.
+life by converting them to their non-scaled version during inference.
 """

 import copy
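A rough sketch of the inference-time workflow this module docstring describes, using `convert_scaled_to_non_scaled` on a toy module so the example stays self-contained. The toy `nn.Sequential` is a placeholder for a real trained model (e.g. one built by `get_transducer_model` in the recipe's train.py), and it is assumed here that the converted module is returned when `inplace=False`.

import torch
import torch.nn as nn

from scaling import ScaledLinear
from scaling_converter import convert_scaled_to_non_scaled

# A toy module with one scaled layer, only to exercise the converter;
# real usage would pass a trained transducer model instead.
toy = nn.Sequential(ScaledLinear(in_features=8, out_features=4))
toy.eval()

# Replace ScaledLinear/ScaledConv1d/ScaledConv2d/ScaledEmbedding modules with
# their plain torch.nn counterparts for inference or export; assumed to return
# the converted copy when inplace=False.
converted = convert_scaled_to_non_scaled(toy, inplace=False)

x = torch.randn(2, 8)
assert torch.allclose(toy(x), converted(x), atol=1e-6)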
@@ -28,15 +28,7 @@ import re

 import torch
 import torch.nn as nn
-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear


-def _get_weight(self: torch.nn.Linear):
-    return self.weight
-
-
-def _get_bias(self: torch.nn.Linear):
-    return self.bias
-
-
 def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
@@ -54,10 +46,6 @@ def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
     """
     assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)

-    # if not hasattr(torch.nn.Linear, "get_weight"):
-    #     torch.nn.Linear.get_weight = _get_weight
-    #     torch.nn.Linear.get_bias = _get_bias
-
     weight = scaled_linear.get_weight()
     bias = scaled_linear.get_bias()
     has_bias = bias is not None
@@ -148,6 +136,34 @@ def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
     return conv2d


+def scaled_embedding_to_embedding(
+    scaled_embedding: ScaledEmbedding,
+) -> nn.Embedding:
+    """Convert an instance of ScaledEmbedding to nn.Embedding.
+
+    Args:
+      scaled_embedding:
+        The layer to be converted.
+    Returns:
+      Return an instance of nn.Embedding that has the same `forward()` behavior
+      of the given `scaled_embedding`.
+    """
+    assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)
+    embedding = nn.Embedding(
+        num_embeddings=scaled_embedding.num_embeddings,
+        embedding_dim=scaled_embedding.embedding_dim,
+        padding_idx=scaled_embedding.padding_idx,
+        scale_grad_by_freq=scaled_embedding.scale_grad_by_freq,
+        sparse=scaled_embedding.sparse,
+    )
+    weight = scaled_embedding.weight
+    scale = scaled_embedding.scale
+
+    embedding.weight.data.copy_(weight * scale.exp())
+
+    return embedding
+
+
 def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
     """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
     in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
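The `weight * scale.exp()` copy above preserves `forward()` because scaling the whole embedding table once is the same as looking up rows first and scaling each result. A small self-contained sketch of that identity in plain torch; the scalar `scale` here is only a stand-in for ScaledEmbedding's learned log-scale parameter (an assumption about its internals), not the actual module.

import torch
import torch.nn.functional as F

weight = torch.randn(500, 10)     # embedding table
scale = torch.tensor(0.25)        # stand-in for the learned log-scale
x = torch.randint(0, 500, (16,))  # token IDs

# Option A: look up rows, then multiply by exp(scale),
# analogous to what the scaled module computes.
y_scaled = F.embedding(x, weight) * scale.exp()

# Option B: bake exp(scale) into the weights once, then do a plain lookup,
# which is what scaled_embedding_to_embedding produces.
y_plain = F.embedding(x, weight * scale.exp())

assert torch.allclose(y_scaled, y_plain)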
@@ -178,6 +194,8 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
             d[name] = scaled_conv1d_to_conv1d(m)
         elif isinstance(m, ScaledConv2d):
             d[name] = scaled_conv2d_to_conv2d(m)
+        elif isinstance(m, ScaledEmbedding):
+            d[name] = scaled_embedding_to_embedding(m)

     for k, v in d.items():
         if "." in k:
@@ -25,11 +25,12 @@ To run this file, do:
 import copy

 import torch
-from scaling import ScaledConv1d, ScaledConv2d, ScaledLinear
+from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
 from scaling_converter import (
     convert_scaled_to_non_scaled,
     scaled_conv1d_to_conv1d,
     scaled_conv2d_to_conv2d,
+    scaled_embedding_to_embedding,
     scaled_linear_to_linear,
 )
 from train import get_params, get_transducer_model
@@ -135,6 +136,21 @@ def test_scaled_conv2d_to_conv2d():
     assert torch.allclose(y1, y4)


+def test_scaled_embedding_to_embedding():
+    scaled_embedding = ScaledEmbedding(
+        num_embeddings=500,
+        embedding_dim=10,
+        padding_idx=0,
+    )
+    embedding = scaled_embedding_to_embedding(scaled_embedding)
+
+    for s in [10, 100, 300, 500, 800, 1000]:
+        x = torch.randint(low=0, high=500, size=(s,))
+        scaled_y = scaled_embedding(x)
+        y = embedding(x)
+        assert torch.equal(scaled_y, y)
+
+
 def test_convert_scaled_to_non_scaled():
     for inplace in [False, True]:
         model = get_model()
@@ -193,6 +209,7 @@ def main():
     test_scaled_linear_to_linear()
     test_scaled_conv1d_to_conv1d()
     test_scaled_conv2d_to_conv2d()
+    test_scaled_embedding_to_embedding()
     test_convert_scaled_to_non_scaled()
@@ -334,10 +334,13 @@ class Nbest(object):
         if hasattr(lattice, "aux_labels"):
             # delete token IDs as it is not needed
             del word_fsa.aux_labels
-            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(
+                word_fsa
+            )
         else:
-            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa)
+            word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(
+                word_fsa
+            )

         path_to_utt_map = self.shape.row_ids(1)