mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Use learnable scales for joiner and decoder
This commit is contained in:
parent
2117f46361
commit
6042c96db2
@ -17,6 +17,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import Optional
|
||||||
|
from subsampling import ScaledConv1d
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
@ -52,7 +55,7 @@ class Decoder(nn.Module):
|
|||||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding = nn.Embedding(
|
self.embedding = ScaledEmbedding(
|
||||||
num_embeddings=vocab_size,
|
num_embeddings=vocab_size,
|
||||||
embedding_dim=embedding_dim,
|
embedding_dim=embedding_dim,
|
||||||
padding_idx=blank_id,
|
padding_idx=blank_id,
|
||||||
@ -62,7 +65,7 @@ class Decoder(nn.Module):
|
|||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
self.context_size = context_size
|
self.context_size = context_size
|
||||||
if context_size > 1:
|
if context_size > 1:
|
||||||
self.conv = nn.Conv1d(
|
self.conv = ScaledConv1d(
|
||||||
in_channels=embedding_dim,
|
in_channels=embedding_dim,
|
||||||
out_channels=embedding_dim,
|
out_channels=embedding_dim,
|
||||||
kernel_size=context_size,
|
kernel_size=context_size,
|
||||||
@ -97,3 +100,183 @@ class Decoder(nn.Module):
|
|||||||
embedding_out = self.conv(embedding_out)
|
embedding_out = self.conv(embedding_out)
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
return embedding_out
|
return embedding_out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ScaledEmbedding(nn.Module):
|
||||||
|
r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
|
||||||
|
|
||||||
|
This module is often used to store word embeddings and retrieve them using indices.
|
||||||
|
The input to the module is a list of indices, and the output is the corresponding
|
||||||
|
word embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_embeddings (int): size of the dictionary of embeddings
|
||||||
|
embedding_dim (int): the size of each embedding vector
|
||||||
|
padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
|
||||||
|
(initialized to zeros) whenever it encounters the index.
|
||||||
|
max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
|
||||||
|
is renormalized to have norm :attr:`max_norm`.
|
||||||
|
norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
|
||||||
|
scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
|
||||||
|
the words in the mini-batch. Default ``False``.
|
||||||
|
sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
|
||||||
|
See Notes for more details regarding sparse gradients.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
|
||||||
|
initialized from :math:`\mathcal{N}(0, 1)`
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
|
||||||
|
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
Keep in mind that only a limited number of optimizers support
|
||||||
|
sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
|
||||||
|
:class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
With :attr:`padding_idx` set, the embedding vector at
|
||||||
|
:attr:`padding_idx` is initialized to all zeros. However, note that this
|
||||||
|
vector can be modified afterwards, e.g., using a customized
|
||||||
|
initialization method, and thus changing the vector used to pad the
|
||||||
|
output. The gradient for this vector from :class:`~torch.nn.Embedding`
|
||||||
|
is always zero.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> # an Embedding module containing 10 tensors of size 3
|
||||||
|
>>> embedding = nn.Embedding(10, 3)
|
||||||
|
>>> # a batch of 2 samples of 4 indices each
|
||||||
|
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
|
||||||
|
>>> embedding(input)
|
||||||
|
tensor([[[-0.0251, -1.6902, 0.7172],
|
||||||
|
[-0.6431, 0.0748, 0.6969],
|
||||||
|
[ 1.4970, 1.3448, -0.9685],
|
||||||
|
[-0.3677, -2.7265, -0.1685]],
|
||||||
|
|
||||||
|
[[ 1.4970, 1.3448, -0.9685],
|
||||||
|
[ 0.4362, -0.4004, 0.9400],
|
||||||
|
[-0.6431, 0.0748, 0.6969],
|
||||||
|
[ 0.9124, -2.3616, 1.1151]]])
|
||||||
|
|
||||||
|
|
||||||
|
>>> # example with padding_idx
|
||||||
|
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
|
||||||
|
>>> input = torch.LongTensor([[0,2,0,5]])
|
||||||
|
>>> embedding(input)
|
||||||
|
tensor([[[ 0.0000, 0.0000, 0.0000],
|
||||||
|
[ 0.1535, -2.0309, 0.9315],
|
||||||
|
[ 0.0000, 0.0000, 0.0000],
|
||||||
|
[-0.1655, 0.9897, 0.0635]]])
|
||||||
|
"""
|
||||||
|
__constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
|
||||||
|
'scale_grad_by_freq', 'sparse']
|
||||||
|
|
||||||
|
num_embeddings: int
|
||||||
|
embedding_dim: int
|
||||||
|
padding_idx: int
|
||||||
|
scale_grad_by_freq: bool
|
||||||
|
weight: Tensor
|
||||||
|
sparse: bool
|
||||||
|
|
||||||
|
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
|
||||||
|
scale_grad_by_freq: bool = False,
|
||||||
|
sparse: bool = False, _weight: Optional[Tensor] = None,
|
||||||
|
scale_speed: float = 5.0) -> None:
|
||||||
|
super(ScaledEmbedding, self).__init__()
|
||||||
|
self.num_embeddings = num_embeddings
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
if padding_idx is not None:
|
||||||
|
if padding_idx > 0:
|
||||||
|
assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||||
|
elif padding_idx < 0:
|
||||||
|
assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
|
||||||
|
padding_idx = self.num_embeddings + padding_idx
|
||||||
|
self.padding_idx = padding_idx
|
||||||
|
self.scale_grad_by_freq = scale_grad_by_freq
|
||||||
|
|
||||||
|
self.scale_speed = scale_speed
|
||||||
|
self.scale = nn.Parameter(torch.tensor(embedding_dim**0.5).log() / scale_speed)
|
||||||
|
|
||||||
|
if _weight is None:
|
||||||
|
self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
|
||||||
|
self.reset_parameters()
|
||||||
|
else:
|
||||||
|
assert list(_weight.shape) == [num_embeddings, embedding_dim], \
|
||||||
|
'Shape of weight does not match num_embeddings and embedding_dim'
|
||||||
|
self.weight = nn.Parameter(_weight)
|
||||||
|
self.sparse = sparse
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
nn.init.normal_(self.weight, std=self.embedding_dim**-0.5)
|
||||||
|
if self.padding_idx is not None:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.weight[self.padding_idx].fill_(0)
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
scale = (self.scale * self.scale_speed).exp()
|
||||||
|
if input.numel() < self.num_embeddings:
|
||||||
|
return F.embedding(
|
||||||
|
input, self.weight, self.padding_idx,
|
||||||
|
None, 2.0, # None, 2.0 relate to normalization
|
||||||
|
self.scale_grad_by_freq, self.sparse) * scale
|
||||||
|
else:
|
||||||
|
return F.embedding(
|
||||||
|
input, self.weight * scale, self.padding_idx,
|
||||||
|
None, 2.0, # None, 2.0 relates to normalization
|
||||||
|
self.scale_grad_by_freq, self.sparse)
|
||||||
|
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}'
|
||||||
|
if self.padding_idx is not None:
|
||||||
|
s += ', padding_idx={padding_idx}'
|
||||||
|
if self.scale_grad_by_freq is not False:
|
||||||
|
s += ', scale_grad_by_freq={scale_grad_by_freq}'
|
||||||
|
if self.sparse is not False:
|
||||||
|
s += ', sparse=True'
|
||||||
|
return s.format(**self.__dict__)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
|
||||||
|
max_norm=None, norm_type=2., scale_grad_by_freq=False,
|
||||||
|
sparse=False):
|
||||||
|
r"""Creates Embedding instance from given 2-dimensional FloatTensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embeddings (Tensor): FloatTensor containing weights for the Embedding.
|
||||||
|
First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
|
||||||
|
freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
|
||||||
|
Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
|
||||||
|
padding_idx (int, optional): See module initialization documentation.
|
||||||
|
max_norm (float, optional): See module initialization documentation.
|
||||||
|
norm_type (float, optional): See module initialization documentation. Default ``2``.
|
||||||
|
scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
|
||||||
|
sparse (bool, optional): See module initialization documentation.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> # FloatTensor containing pretrained weights
|
||||||
|
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
|
||||||
|
>>> embedding = nn.Embedding.from_pretrained(weight)
|
||||||
|
>>> # Get embeddings for index 1
|
||||||
|
>>> input = torch.LongTensor([1])
|
||||||
|
>>> embedding(input)
|
||||||
|
tensor([[ 4.0000, 5.1000, 6.3000]])
|
||||||
|
"""
|
||||||
|
assert embeddings.dim() == 2, \
|
||||||
|
'Embeddings parameter is expected to be 2-dimensional'
|
||||||
|
rows, cols = embeddings.shape
|
||||||
|
embedding = cls(
|
||||||
|
num_embeddings=rows,
|
||||||
|
embedding_dim=cols,
|
||||||
|
_weight=embeddings,
|
||||||
|
padding_idx=padding_idx,
|
||||||
|
max_norm=max_norm,
|
||||||
|
norm_type=norm_type,
|
||||||
|
scale_grad_by_freq=scale_grad_by_freq,
|
||||||
|
sparse=sparse)
|
||||||
|
embedding.weight.requires_grad = not freeze
|
||||||
|
return embedding
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from subsampling import ScaledLinear
|
||||||
|
|
||||||
class Joiner(nn.Module):
|
class Joiner(nn.Module):
|
||||||
def __init__(self, input_dim: int, output_dim: int):
|
def __init__(self, input_dim: int, output_dim: int):
|
||||||
@ -24,7 +24,7 @@ class Joiner(nn.Module):
|
|||||||
|
|
||||||
self.input_dim = input_dim
|
self.input_dim = input_dim
|
||||||
self.output_dim = output_dim
|
self.output_dim = output_dim
|
||||||
self.output_linear = nn.Linear(input_dim, output_dim)
|
self.output_linear = ScaledLinear(input_dim, output_dim)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@ -110,7 +110,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/randcombine1_expscale3_rework2b",
|
default="transducer_stateless/randcombine1_expscale3_rework2c",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
|
@ -21,7 +21,7 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
@ -106,7 +106,7 @@ class Transformer(EncoderInterface):
|
|||||||
|
|
||||||
# TODO(fangjun): remove dropout
|
# TODO(fangjun): remove dropout
|
||||||
self.encoder_output_layer = nn.Sequential(
|
self.encoder_output_layer = nn.Sequential(
|
||||||
nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
|
nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user