mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

Revert transducer_stateless/ to state in upstream/master

This commit is contained in: parent 807fcada68 · commit 9f62a0296c
--- a/transducer_stateless/conformer.py
+++ b/transducer_stateless/conformer.py
@@ -18,8 +18,7 @@
 import copy
 import math
 import warnings
-from typing import Optional, Tuple, Sequence
-from subsampling import DoubleSwish, ActivationBalancer, BasicNorm, ScaledLinear, ScaledConv1d, ScaledConv2d
+from typing import Optional, Tuple

 import torch
 from torch import Tensor, nn
@@ -57,7 +56,6 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
-        aux_layer_period: int = 3
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -82,13 +80,17 @@ class Conformer(Transformer):
             cnn_module_kernel,
             normalize_before,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
-                                        aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
+        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
         self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = nn.LayerNorm(d_model)
+        else:
+            # Note: TorchScript detects that self.after_norm could be used inside forward()
+            #       and throws an error without this change.
+            self.after_norm = identity

     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False
+        self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -115,8 +117,10 @@ class Conformer(Transformer):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)

-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
-                         warmup_mode=warmup_mode)  # (T, N, C)
+        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)
+        if self.normalize_before:
+            x = self.after_norm(x)

         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -154,41 +158,42 @@ class ConformerEncoderLayer(nn.Module):
         normalize_before: bool = True,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
-        self.d_model = d_model

         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=0.0
         )

         self.feed_forward = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1),
-            DoubleSwish(),
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            nn.Linear(dim_feedforward, d_model),
         )

         self.feed_forward_macaron = nn.Sequential(
-            ScaledLinear(d_model, dim_feedforward),
-            ActivationBalancer(channel_dim=-1),
-            DoubleSwish(),
+            nn.Linear(d_model, dim_feedforward),
+            Swish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
+            nn.Linear(dim_feedforward, d_model),
         )

         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)

+        self.norm_ff_macaron = nn.LayerNorm(
+            d_model
+        )  # for the macaron style FNN module
+        self.norm_ff = nn.LayerNorm(d_model)  # for the FNN module
+        self.norm_mha = nn.LayerNorm(d_model)  # for the MHA module

-        self.norm_final = BasicNorm(d_model)
+        self.ff_scale = 0.5

-        # try to ensure the output is close to zero-mean (or at least, zero-median).
-        self.balancer = ActivationBalancer(channel_dim=-1,
-                                           min_positive=0.45,
-                                           max_positive=0.55,
-                                           max_abs=6.0)
+        self.norm_conv = nn.LayerNorm(d_model)  # for the CNN module
+        self.norm_final = nn.LayerNorm(
+            d_model
+        )  # for the final output of the block

         self.dropout = nn.Dropout(dropout)

+        self.normalize_before = normalize_before

     def forward(
         self,
@@ -215,10 +220,19 @@ class ConformerEncoderLayer(nn.Module):
         """

         # macaron style feed forward module
-        src = src + self.dropout(self.feed_forward_macaron(src))
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff_macaron(src)
+        src = residual + self.ff_scale * self.dropout(
+            self.feed_forward_macaron(src)
+        )
+        if not self.normalize_before:
+            src = self.norm_ff_macaron(src)

         # multi-headed self-attention module
+        residual = src
+        if self.normalize_before:
+            src = self.norm_mha(src)
         src_att = self.self_attn(
             src,
             src,
@@ -227,15 +241,28 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = src + self.dropout(src_att)
+        src = residual + self.dropout(src_att)
+        if not self.normalize_before:
+            src = self.norm_mha(src)

         # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        residual = src
+        if self.normalize_before:
+            src = self.norm_conv(src)
+        src = residual + self.dropout(self.conv_module(src))
+        if not self.normalize_before:
+            src = self.norm_conv(src)

         # feed forward module
-        src = src + self.dropout(self.feed_forward(src))
+        residual = src
+        if self.normalize_before:
+            src = self.norm_ff(src)
+        src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
+        if not self.normalize_before:
+            src = self.norm_ff(src)

-        src = self.norm_final(self.balancer(src))
+        if self.normalize_before:
+            src = self.norm_final(src)

         return src

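A minimal sketch of the pre-norm/post-norm residual pattern that the hunks above restore (a standalone illustration with a hypothetical `sublayer` argument, not code from the repository):

import torch
from torch import nn

def residual_sublayer(src: torch.Tensor, sublayer: nn.Module,
                      norm: nn.LayerNorm, normalize_before: bool) -> torch.Tensor:
    # Pre-norm: normalize the sublayer's input; post-norm: normalize the
    # residual sum. This mirrors the restored forward() code above.
    residual = src
    if normalize_before:
        src = norm(src)
    src = residual + sublayer(src)
    if not normalize_before:
        src = norm(src)
    return src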
@@ -255,20 +282,12 @@ class ConformerEncoder(nn.Module):
         >>> out = conformer_encoder(src, pos_emb)
     """

-    def __init__(self, encoder_layer: nn.Module, num_layers: int,
-                 aux_layers: Sequence[int]) -> None:
+    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
-        self.aux_layers = set(aux_layers + [num_layers - 1])
-        assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
-        num_channels = encoder_layer.d_model
-        self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
-                                      final_weight=0.5,
-                                      pure_prob=0.333,
-                                      stddev=2.0)

     def forward(
         self,
@@ -276,7 +295,6 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup_mode: bool = False
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

@@ -296,19 +314,14 @@ class ConformerEncoder(nn.Module):
         """
         output = src

-        outputs = []
-
-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
-            if i in self.aux_layers:
-                outputs.append(output)

-        output = self.combiner(outputs, warmup_mode)
         return output


@@ -331,6 +344,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Construct an PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
         self.dropout = torch.nn.Dropout(p=dropout_rate)
         self.pe = None
         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
@@ -382,6 +396,7 @@ class RelPositionalEncoding(torch.nn.Module):

         """
         self.extend_pe(x)
+        x = x * self.xscale
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
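The restored `xscale` lines apply the standard Transformer input scaling: multiplying by sqrt(d_model) keeps the token features and the positional terms at comparable magnitude. A tiny sketch of the effect (hypothetical shapes, not repository code):

import math
import torch

d_model = 512
x = torch.randn(2, 100, d_model)  # (batch, time, d_model)
x = x * math.sqrt(d_model)        # same operation as the restored xscale lines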
@@ -413,7 +428,6 @@ class RelPositionMultiheadAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         dropout: float = 0.0,
-        scale_speed: float = 5.0
     ) -> None:
         super(RelPositionMultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
@@ -424,29 +438,25 @@ class RelPositionMultiheadAttention(nn.Module):
             self.head_dim * num_heads == self.embed_dim
         ), "embed_dim must be divisible by num_heads"

-        self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True)
-        self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=True, initial_scale=0.25)
+        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

         # linear transformation for positional encoding.
-        self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
+        self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
         self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
         self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
-        self.scale_speed = scale_speed
-        self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
-        self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
         self._reset_parameters()

-    def _pos_bias_u(self):
-        return self.pos_bias_u * (self.pos_bias_u_scale * self.scale_speed).exp()
-
-    def _pos_bias_v(self):
-        return self.pos_bias_v * (self.pos_bias_v_scale * self.scale_speed).exp()
-
     def _reset_parameters(self) -> None:
-        nn.init.normal_(self.pos_bias_u, std=0.05)
-        nn.init.normal_(self.pos_bias_v, std=0.05)
+        nn.init.xavier_uniform_(self.in_proj.weight)
+        nn.init.constant_(self.in_proj.bias, 0.0)
+        nn.init.constant_(self.out_proj.bias, 0.0)

+        nn.init.xavier_uniform_(self.pos_bias_u)
+        nn.init.xavier_uniform_(self.pos_bias_v)

     def forward(
         self,
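A note on the removed `_pos_bias_u`/`_pos_bias_v` helpers: the experimental branch stored each positional bias as a raw parameter times `exp(scale * scale_speed)`, so the bias magnitude is learned multiplicatively (starting at exp(0) = 1), and `scale_speed` scales the gradient seen by the log-scale parameter via the chain rule. A self-contained sketch of that parametrization (an illustration, not the restored code):

import torch
from torch import nn

class ExpScaled(nn.Module):
    # Effective parameter = weight * exp(log_scale * scale_speed).
    def __init__(self, shape, scale_speed: float = 5.0) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(shape) * 0.05)
        self.log_scale = nn.Parameter(torch.zeros(()))
        self.scale_speed = scale_speed

    def scaled(self) -> torch.Tensor:
        return self.weight * (self.log_scale * self.scale_speed).exp()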
@@ -506,11 +516,11 @@ class RelPositionMultiheadAttention(nn.Module):
             pos_emb,
             self.embed_dim,
             self.num_heads,
-            self.in_proj.get_weight(),
-            self.in_proj.get_bias(),
+            self.in_proj.weight,
+            self.in_proj.bias,
             self.dropout,
-            self.out_proj.get_weight(),
-            self.out_proj.get_bias(),
+            self.out_proj.weight,
+            self.out_proj.bias,
             training=self.training,
             key_padding_mask=key_padding_mask,
             need_weights=need_weights,
@@ -614,12 +624,13 @@ class RelPositionMultiheadAttention(nn.Module):
         assert (
             head_dim * num_heads == embed_dim
         ), "embed_dim must be divisible by num_heads"

         scaling = float(head_dim) ** -0.5

         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
+            q, k, v = nn.functional.linear(
+                query, in_proj_weight, in_proj_bias
+            ).chunk(3, dim=-1)

         elif torch.equal(key, value):
             # encoder-decoder attention
@@ -651,7 +662,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:_end]
             q = nn.functional.linear(query, _w, _b)

-
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim
@@ -670,7 +680,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)

-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
@@ -720,7 +729,7 @@ class RelPositionMultiheadAttention(nn.Module):
             )
             key_padding_mask = key_padding_mask.to(torch.bool)

-        q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim)
+        q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
         k = k.contiguous().view(-1, bsz, num_heads, head_dim)
         v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

@@ -741,11 +750,11 @@ class RelPositionMultiheadAttention(nn.Module):
         p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
         p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)

-        q_with_bias_u = (q + self._pos_bias_u()).transpose(
+        q_with_bias_u = (q + self.pos_bias_u).transpose(
             1, 2
         )  # (batch, head, time1, d_k)

-        q_with_bias_v = (q + self._pos_bias_v()).transpose(
+        q_with_bias_v = (q + self.pos_bias_v).transpose(
             1, 2
         )  # (batch, head, time1, d_k)

@@ -765,7 +774,7 @@ class RelPositionMultiheadAttention(nn.Module):

         attn_output_weights = (
             matrix_ac + matrix_bd
-        )  # (batch, head, time1, time2)
+        ) * scaling  # (batch, head, time1, time2)

         attn_output_weights = attn_output_weights.view(
             bsz * num_heads, tgt_len, -1
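The two `scaling` hunks above are the same change seen from both ends: the revert moves the 1/sqrt(head_dim) factor off the query and back onto the summed attention scores. Because the scores are linear in `q`, both placements are mathematically equivalent; a quick numerical check (hypothetical shapes):

import torch

head_dim = 64
scaling = float(head_dim) ** -0.5
q = torch.randn(4, 10, head_dim)  # (head, time1, d_k)
k = torch.randn(4, 12, head_dim)  # (head, time2, d_k)

scores_a = torch.matmul(q * scaling, k.transpose(-2, -1))
scores_b = torch.matmul(q, k.transpose(-2, -1)) * scaling
assert torch.allclose(scores_a, scores_b, atol=1e-6)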
@@ -840,7 +849,7 @@ class ConvolutionModule(nn.Module):
         # kernerl_size should be a odd number for 'SAME' padding
         assert (kernel_size - 1) % 2 == 0

-        self.pointwise_conv1 = ScaledConv1d(
+        self.pointwise_conv1 = nn.Conv1d(
             channels,
             2 * channels,
             kernel_size=1,
@@ -848,25 +857,7 @@ class ConvolutionModule(nn.Module):
             padding=0,
             bias=bias,
         )
-        # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu).
-        # For most layers the normal rms value of channels of x seems to be in the range 1 to 4,
-        # but sometimes, for some reason, for layer 0 the rms ends up being very large,
-        # between 50 and 100 for different channels.  This will cause very peaky and
-        # sparse derivatives for the sigmoid gating function, which will tend to make
-        # the loss function not learn effectively.  (for most layers the average absolute values
-        # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion,
-        # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different
-        # layers, which likely breaks down as 0.5 for the "linear" half and
-        # 0.2 to 0.3 for the part that goes into the sigmoid.  The idea is that if we
-        # constrain the rms values to a reasonable range via a constraint of max_abs=10.0,
-        # it will be in a better position to start learning something, i.e. to latch onto
-        # the correct range.
-        self.deriv_balancer1 = ActivationBalancer(channel_dim=1, max_abs=10.0,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
-
-        self.depthwise_conv = ScaledConv1d(
+        self.depthwise_conv = nn.Conv1d(
             channels,
             channels,
             kernel_size,
@@ -875,22 +866,16 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.deriv_balancer2 = ActivationBalancer(channel_dim=1,
-                                                  min_positive=0.05,
-                                                  max_positive=1.0)
-
-        self.activation = DoubleSwish()
-
-        self.pointwise_conv2 = ScaledConv1d(
+        self.norm = nn.LayerNorm(channels)
+        self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
             kernel_size=1,
             stride=1,
             padding=0,
             bias=bias,
-            initial_scale=0.25
         )
+        self.activation = Swish()

     def forward(self, x: Tensor) -> Tensor:
         """Compute convolution module.
@@ -907,14 +892,15 @@ class ConvolutionModule(nn.Module):

         # GLU mechanism
         x = self.pointwise_conv1(x)  # (batch, 2*channels, time)

-        x = self.deriv_balancer1(x)
         x = nn.functional.glu(x, dim=1)  # (batch, channels, time)

         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
+        # x is (batch, channels, time)
+        x = x.permute(0, 2, 1)
+        x = self.norm(x)
+        x = x.permute(0, 2, 1)

-        x = self.deriv_balancer2(x)
         x = self.activation(x)

         x = self.pointwise_conv2(x)  # (batch, channel, time)
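For the `# GLU mechanism` step above: `nn.functional.glu(x, dim=1)` halves the `2 * channels` output of `pointwise_conv1` and gates one half with the sigmoid of the other, which is why the removed comment worried about very large pre-sigmoid values saturating the gate. A minimal sketch of the operation:

import torch

x = torch.randn(4, 2 * 16, 50)  # (batch, 2*channels, time)
a, b = x.chunk(2, dim=1)        # first half is linear, second half is the gate
manual = a * torch.sigmoid(b)
assert torch.allclose(manual, torch.nn.functional.glu(x, dim=1))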
@@ -922,197 +908,13 @@ class ConvolutionModule(nn.Module):
         return x.permute(2, 0, 1)


-class Identity(torch.nn.Module):
+class Swish(torch.nn.Module):
+    """Construct an Swish object."""

     def forward(self, x: Tensor) -> Tensor:
-        return x
+        """Return Swich activation function."""
+        return x * torch.sigmoid(x)


-class RandomCombine(torch.nn.Module):
-    """
-    This module combines a list of Tensors, all with the same shape, to
-    produce a single output of that same shape which, in training time,
-    is a random combination of all the inputs; but which in test time
-    will be just the last input.
-
-    The idea is that the list of Tensors will be a list of outputs of multiple
-    conformer layers.  This has a similar effect as iterated loss. (See:
-    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
-    NETWORKS).
-    """
-    def __init__(self, num_inputs: int,
-                 final_weight: float = 0.5,
-                 pure_prob: float = 0.5,
-                 stddev: float = 2.0) -> None:
-        """
-        Args:
-          num_inputs: The number of tensor inputs, which equals the number of layers'
-                outputs that are fed into this module.  E.g. in an 18-layer neural
-                net if we output layers 16, 12, 18, num_inputs would be 3.
-          final_weight: The amount of weight or probability we assign to the
-                final layer when randomly choosing layers or when choosing
-                continuous layer weights.
-          pure_prob: The probability, on each frame, with which we choose
-                only a single layer to output (rather than an interpolation)
-          stddev: A standard deviation that we add to log-probs for computing
-                randomized weights.
-
-        The method of choosing which layers,
-        or combinations of layers, to use, is conceptually as follows.
-            With probability `pure_prob`:
-                With probability `final_weight`: choose final layer,
-                Else: choose random non-final layer.
-            Else:
-                Choose initial log-weights that correspond to assigning
-                weight `final_weight` to the final layer and equal
-                weights to other layers; then add Gaussian noise
-                with variance `stddev` to these log-weights, and normalize
-                to weights (note: the average weight assigned to the
-                final layer here will not be `final_weight` if stddev>0).
-        """
-        super(RandomCombine, self).__init__()
-        assert pure_prob >= 0 and pure_prob <= 1
-        assert final_weight > 0 and final_weight < 1
-        assert num_inputs >= 1
-
-        self.num_inputs = num_inputs
-        self.final_weight = final_weight
-        self.pure_prob = pure_prob
-        self.stddev = stddev
-
-        self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item()
-
-    def forward(self, inputs: Sequence[Tensor],
-                warmup_mode: bool) -> Tensor:
-        """
-        Forward function.
-        Args:
-          inputs: a list of Tensor, e.g. from various layers of a transformer.
-                All must be the same shape, of (*, num_channels)
-        Returns:
-          a Tensor of shape (*, num_channels).  In test mode
-          this is just the final input.
-        """
-        num_inputs = self.num_inputs
-        assert len(inputs) == num_inputs
-        if not (self.training and warmup_mode):
-            return inputs[-1]
-
-        # Shape of weights: (*, num_inputs)
-        num_channels = inputs[0].shape[-1]
-        num_frames = inputs[0].numel() // num_channels
-
-        ndim = inputs[0].ndim
-        # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames,
-                                                                num_channels,
-                                                                num_inputs))
-
-        # weights: (num_frames, num_inputs)
-        weights = self._get_random_weights(inputs[0].dtype, inputs[0].device,
-                                           num_frames)
-
-        weights = weights.reshape(num_frames, num_inputs, 1)
-        # ans: (num_frames, num_channels, 1)
-        ans = torch.matmul(stacked_inputs, weights)
-        # ans: (*, num_channels)
-        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
-
-        if __name__ == "__main__":
-            # for testing only...
-            print("Weights = ", weights.reshape(num_frames, num_inputs))
-        return ans
-
-    def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor:
-        """
-        Return a tensor of random weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a tensor of shape (num_frames, self.num_inputs), such that
-          ans.sum(dim=1) is all ones.
-        """
-        pure_prob = self.pure_prob
-        if pure_prob == 0.0:
-            return self._get_random_mixed_weights(dtype, device, num_frames)
-        elif pure_prob == 1.0:
-            return self._get_random_pure_weights(dtype, device, num_frames)
-        else:
-            p = self._get_random_pure_weights(dtype, device, num_frames)
-            m = self._get_random_mixed_weights(dtype, device, num_frames)
-            return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m)
-
-    def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
-        """
-        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with
-          exactly one weight equal to 1.0 on each frame.
-        """
-        final_prob = self.final_weight
-
-        # final contains self.num_inputs - 1 in all elements
-        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
-        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
-        nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
-
-        indexes = torch.where(torch.rand(num_frames, device=device) < final_prob,
-                              final, nonfinal)
-        ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype)
-        return ans
-
-    def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
-        """
-        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that
-          sum to one over the second axis, i.e. ans.sum(dim=1) is all ones.
-        """
-        logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev
-        logprobs[:, -1] += self.final_log_weight
-        return logprobs.softmax(dim=1)
-
-
-def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
-    print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}")
-    num_inputs = 3
-    num_channels = 50
-    m = RandomCombine(num_inputs=num_inputs,
-                      final_weight=final_weight,
-                      pure_prob=pure_prob,
-                      stddev=stddev)
-
-    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
-
-    y = m(x, True)
-    assert y.shape == x[0].shape
-    assert torch.allclose(y, x[0])  # .. since actually all ones.
-
-
-if __name__ == '__main__':
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.0)
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.3)
-    _test_random_combine(0.5, 1, 0.3)
-    _test_random_combine(0.5, 0.5, 0.3)
-
-    feature_dim = 50
-    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
-    batch_size = 5
-    seq_len = 20
-    # Just make sure the forward pass runs.
-    f = c(torch.randn(batch_size, seq_len, feature_dim),
-          torch.full((batch_size,), seq_len, dtype=torch.int64),
-          warmup_mode=True)
+def identity(x):
+    return x
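To make the removed `RandomCombine` weighting concrete: `final_log_weight` is chosen so that, with zero added noise, the softmax over the base log-weights assigns exactly `final_weight` to the final layer and splits the remainder evenly across the others. A small numeric check of that identity, using the defaults visible above:

import torch

num_inputs, final_weight = 3, 0.5
final_log_weight = torch.tensor(
    (final_weight / (1 - final_weight)) * (num_inputs - 1)
).log().item()

# With zero noise, the mixed weights reduce to softmax of the base log-weights.
logprobs = torch.zeros(num_inputs)  # equal base weight for every layer
logprobs[-1] += final_log_weight    # boost the final layer
weights = logprobs.softmax(dim=0)   # -> tensor([0.25, 0.25, 0.50])
assert abs(weights[-1].item() - final_weight) < 1e-6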
--- a/transducer_stateless/decoder.py
+++ b/transducer_stateless/decoder.py
@@ -17,9 +17,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import Tensor
-from typing import Optional
-from subsampling import ScaledConv1d


 class Decoder(nn.Module):
@@ -55,7 +52,7 @@ class Decoder(nn.Module):
           1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
-        self.embedding = ScaledEmbedding(
+        self.embedding = nn.Embedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
@@ -65,7 +62,7 @@ class Decoder(nn.Module):
         assert context_size >= 1, context_size
         self.context_size = context_size
         if context_size > 1:
-            self.conv = ScaledConv1d(
+            self.conv = nn.Conv1d(
                 in_channels=embedding_dim,
                 out_channels=embedding_dim,
                 kernel_size=context_size,
@@ -85,7 +82,6 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        y = y.to(torch.int64)
         embedding_out = self.embedding(y)
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
@@ -100,139 +96,3 @@ class Decoder(nn.Module):
             embedding_out = self.conv(embedding_out)
             embedding_out = embedding_out.permute(0, 2, 1)
         return embedding_out
-
-
-
-class ScaledEmbedding(nn.Module):
-    r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
-
-    This module is often used to store word embeddings and retrieve them using indices.
-    The input to the module is a list of indices, and the output is the corresponding
-    word embeddings.
-
-    Args:
-        num_embeddings (int): size of the dictionary of embeddings
-        embedding_dim (int): the size of each embedding vector
-        padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx`
-                                     (initialized to zeros) whenever it encounters the index.
-        max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
-                                    is renormalized to have norm :attr:`max_norm`.
-        norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
-        scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of
-                                                the words in the mini-batch. Default ``False``.
-        sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
-                                 See Notes for more details regarding sparse gradients.
-
-    Attributes:
-        weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
-                         initialized from :math:`\mathcal{N}(0, 1)`
-
-    Shape:
-        - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract
-        - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
-
-    .. note::
-        Keep in mind that only a limited number of optimizers support
-        sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
-        :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
-
-    .. note::
-        With :attr:`padding_idx` set, the embedding vector at
-        :attr:`padding_idx` is initialized to all zeros. However, note that this
-        vector can be modified afterwards, e.g., using a customized
-        initialization method, and thus changing the vector used to pad the
-        output. The gradient for this vector from :class:`~torch.nn.Embedding`
-        is always zero.
-
-    Examples::
-
-        >>> # an Embedding module containing 10 tensors of size 3
-        >>> embedding = nn.Embedding(10, 3)
-        >>> # a batch of 2 samples of 4 indices each
-        >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
-        >>> embedding(input)
-        tensor([[[-0.0251, -1.6902,  0.7172],
-                 [-0.6431,  0.0748,  0.6969],
-                 [ 1.4970,  1.3448, -0.9685],
-                 [-0.3677, -2.7265, -0.1685]],
-
-                [[ 1.4970,  1.3448, -0.9685],
-                 [ 0.4362, -0.4004,  0.9400],
-                 [-0.6431,  0.0748,  0.6969],
-                 [ 0.9124, -2.3616,  1.1151]]])
-
-
-        >>> # example with padding_idx
-        >>> embedding = nn.Embedding(10, 3, padding_idx=0)
-        >>> input = torch.LongTensor([[0,2,0,5]])
-        >>> embedding(input)
-        tensor([[[ 0.0000,  0.0000,  0.0000],
-                 [ 0.1535, -2.0309,  0.9315],
-                 [ 0.0000,  0.0000,  0.0000],
-                 [-0.1655,  0.9897,  0.0635]]])
-    """
-    __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx',
-                     'scale_grad_by_freq', 'sparse']
-
-    num_embeddings: int
-    embedding_dim: int
-    padding_idx: int
-    scale_grad_by_freq: bool
-    weight: Tensor
-    sparse: bool
-
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
-                 scale_grad_by_freq: bool = False,
-                 sparse: bool = False,
-                 scale_speed: float = 5.0) -> None:
-        super(ScaledEmbedding, self).__init__()
-        self.num_embeddings = num_embeddings
-        self.embedding_dim = embedding_dim
-        if padding_idx is not None:
-            if padding_idx > 0:
-                assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
-            elif padding_idx < 0:
-                assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
-                padding_idx = self.num_embeddings + padding_idx
-        self.padding_idx = padding_idx
-        self.scale_grad_by_freq = scale_grad_by_freq
-
-        self.scale_speed = scale_speed
-        self.scale = nn.Parameter(torch.zeros(()))  # see reset_parameters()
-        self.sparse = sparse
-
-        self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim))
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        nn.init.normal_(self.weight, std=0.05)
-        nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed)
-
-        if self.padding_idx is not None:
-            with torch.no_grad():
-                self.weight[self.padding_idx].fill_(0)
-
-    def forward(self, input: Tensor) -> Tensor:
-        scale = (self.scale * self.scale_speed).exp()
-        if input.numel() < self.num_embeddings:
-            return F.embedding(
-                input, self.weight, self.padding_idx,
-                None, 2.0,  # None, 2.0 relate to normalization
-                self.scale_grad_by_freq, self.sparse) * scale
-        else:
-            return F.embedding(
-                input, self.weight * scale, self.padding_idx,
-                None, 2.0,  # None, 2.0 relates to normalization
-                self.scale_grad_by_freq, self.sparse)
-
-    def extra_repr(self) -> str:
-        s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}'
-        if self.padding_idx is not None:
-            s += ', padding_idx={padding_idx}'
-        if self.scale_grad_by_freq is not False:
-            s += ', scale_grad_by_freq={scale_grad_by_freq}'
-        if self.sparse is not False:
-            s += ', sparse=True'
-        return s.format(**self.__dict__)
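One detail of the removed `ScaledEmbedding.forward` worth spelling out: scaling the looked-up rows and scaling the whole embedding table give identical results, so the branch on `input.numel() < self.num_embeddings` simply multiplies whichever tensor has fewer elements. A quick sketch of that equivalence (hypothetical sizes, not repository code):

import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)      # (num_embeddings, embedding_dim)
idx = torch.tensor([[1, 2, 4]])  # token indices
scale = torch.tensor(0.7)

assert torch.allclose(
    F.embedding(idx, weight) * scale,  # scale the few looked-up rows
    F.embedding(idx, weight * scale),  # scale the whole table
)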
--- a/transducer_stateless/encoder_interface.py
+++ b/transducer_stateless/encoder_interface.py
@@ -22,7 +22,7 @@ import torch.nn as nn

 class EncoderInterface(nn.Module):
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
--- a/transducer_stateless/joiner.py
+++ b/transducer_stateless/joiner.py
@@ -16,7 +16,7 @@

 import torch
 import torch.nn as nn
-from subsampling import ScaledLinear

 class Joiner(nn.Module):
     def __init__(self, input_dim: int, output_dim: int):
@@ -24,7 +24,7 @@ class Joiner(nn.Module):

         self.input_dim = input_dim
         self.output_dim = output_dim
-        self.output_linear = ScaledLinear(input_dim, output_dim)
+        self.output_linear = nn.Linear(input_dim, output_dim)

     def forward(
         self,
--- a/transducer_stateless/model.py
+++ b/transducer_stateless/model.py
@@ -65,7 +65,6 @@ class Transducer(nn.Module):
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
         modified_transducer_prob: float = 0.0,
-        warmup_mode: bool = False
     ) -> torch.Tensor:
         """
         Args:
@@ -88,7 +87,7 @@ class Transducer(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode)
+        encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
--- a/transducer_stateless/train.py
+++ b/transducer_stateless/train.py
@@ -111,8 +111,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        # was 2c_maxabs1000_maxp0.95_noexp_convderiv2warmup_scale_0mean, then reworking initialization..
-        default="transducer_stateless/randcombine1_expscale3_rework2d",
+        default="transducer_stateless/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -223,7 +222,6 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
-            "warmup_minibatches": 3000,  # use warmup mode for 3k minibatches.
             # parameters for conformer
             "feature_dim": 80,
             "encoder_out_dim": 512,
@@ -381,7 +379,6 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    is_warmup_mode: bool = False
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -418,7 +415,6 @@ def compute_loss(
             x_lens=feature_lens,
             y=y,
             modified_transducer_prob=params.modified_transducer_prob,
-            warmup_mode=is_warmup_mode
         )

     assert loss.requires_grad == is_training
@@ -455,7 +451,6 @@ def compute_validation_loss(
             sp=sp,
             batch=batch,
             is_training=False,
-            is_warmup_mode=False
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -517,7 +512,6 @@ def train_one_epoch(
             sp=sp,
             batch=batch,
             is_training=True,
-            is_warmup_mode=(params.batch_idx_train < params.warmup_minibatches)
         )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -750,7 +744,6 @@ def scan_pessimistic_batches_for_oom(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                is_warmup_mode=False
             )
             loss.backward()
             clip_grad_norm_(model.parameters(), 5.0, 2.0)
--- a/transducer_stateless/transformer.py
+++ b/transducer_stateless/transformer.py
@@ -21,7 +21,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from subsampling import Conv2dSubsampling, VggSubsampling, ScaledLinear
+from subsampling import Conv2dSubsampling, VggSubsampling

 from icefall.utils import make_pad_mask

@@ -106,7 +106,7 @@ class Transformer(EncoderInterface):

         # TODO(fangjun): remove dropout
         self.encoder_output_layer = nn.Sequential(
-            nn.Dropout(p=dropout), ScaledLinear(d_model, output_dim)
+            nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
         )

     def forward(