Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-27 18:54:18 +00:00)

add attention-decoder loss option for zipformer recipe

parent 231bbcd2b6, commit d17535f4cd

egs/librispeech/ASR/zipformer/attention_decoder.py (new file, 485 lines)
@@ -0,0 +1,485 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The model structure is modified from Daniel Povey's Zipformer
# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py

import math
from typing import List, Tuple

import k2
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss

from icefall.utils import add_eos, add_sos, make_pad_mask


class AttentionDecoderModel(nn.Module):
    """
    Args:
      vocab_size (int): Number of classes.
      decoder_dim (int): Embedding dimension of the decoder.
      num_decoder_layers (int): Number of decoder layers.
      attention_dim (int): Total dimension of the multi-head attention.
      nhead (int): Number of attention heads.
      feedforward_dim (int): Hidden dimension of the feed-forward module.
      sos_id (int): ID of the SOS token.
      eos_id (int): ID of the EOS token.
      dropout (float): Dropout rate.
      ignore_id (int): ID used for padded positions that are ignored in the loss.
      label_smoothing (float): Label-smoothing rate.
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int = 512,
        num_decoder_layers: int = 6,
        attention_dim: int = 512,
        nhead: int = 8,
        feedforward_dim: int = 2048,
        sos_id: int = 1,
        eos_id: int = 1,
        dropout: float = 0.1,
        ignore_id: int = -1,
        label_smoothing: float = 0.1,
    ):
        super().__init__()
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.ignore_id = ignore_id

        self.decoder = TransformerDecoder(
            vocab_size=vocab_size,
            d_model=decoder_dim,
            num_decoder_layers=num_decoder_layers,
            attention_dim=attention_dim,
            nhead=nhead,
            feedforward_dim=feedforward_dim,
            dropout=dropout,
        )

        # Used to calculate attention-decoder loss
        self.loss_fun = LabelSmoothingLoss(
            ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum"
        )

    def _pre_ys_in_out(self, ys: k2.RaggedTensor, ys_lens: torch.Tensor):
        """Prepare ys_in_pad and ys_out_pad."""
        ys_in = add_sos(ys, sos_id=self.sos_id)
        # [B, S+1], start with SOS
        ys_in_pad = ys_in.pad(mode="constant", padding_value=self.eos_id)
        ys_in_lens = ys_lens + 1

        ys_out = add_eos(ys, eos_id=self.eos_id)
        # [B, S+1], end with EOS
        ys_out_pad = ys_out.pad(mode="constant", padding_value=self.ignore_id)

        return ys_in_pad.to(torch.int64), ys_in_lens, ys_out_pad.to(torch.int64)

    def calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys: k2.RaggedTensor,
        ys_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate attention-decoder loss.
        Args:
          encoder_out: (batch, num_frames, encoder_dim)
          encoder_out_lens: (batch,)
          ys: Ragged tensor of token ids, one sub-list per utterance.
          ys_lens: Number of tokens in each utterance, (batch,)

        Return: The attention-decoder loss.
        """
        ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)

        # decoder forward
        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)

        loss = self.loss_fun(x=decoder_out, target=ys_out_pad)
        return loss

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        token_ids: List[List[int]],
    ) -> torch.Tensor:
        """Compute negative log likelihood (nll) from the attention-decoder.
        Args:
          encoder_out: (batch, num_frames, encoder_dim)
          encoder_out_lens: (batch,)
          token_ids: A list of token id lists.

        Return: A tensor of shape (batch, num_tokens).
        """
        ys = k2.RaggedTensor(token_ids).to(device=encoder_out.device)
        row_splits = ys.shape.row_splits(1)
        ys_lens = row_splits[1:] - row_splits[:-1]

        ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)

        # decoder forward
        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)

        batch_size, _, num_classes = decoder_out.size()
        nll = nn.functional.cross_entropy(
            decoder_out.view(-1, num_classes),
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        nll = nll.view(batch_size, -1)
        return nll


class TransformerDecoder(nn.Module):
    """Transformer decoder module.
    It is modified from https://github.com/espnet/espnet/blob/master/espnet2/asr/decoder/transformer_decoder.py.

    Args:
      vocab_size: output dim
      d_model: decoder dimension
      num_decoder_layers: number of decoder layers
      attention_dim: total dimension of multi head attention
      nhead: number of attention heads
      feedforward_dim: hidden dimension of feed_forward module
      dropout: dropout rate
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_decoder_layers: int = 6,
        attention_dim: int = 512,
        nhead: int = 8,
        feedforward_dim: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

        # Using absolute positional encoding
        self.pos = PositionalEncoding(d_model, dropout_rate=0.1)

        self.num_layers = num_decoder_layers
        self.layers = nn.ModuleList(
            [
                DecoderLayer(d_model, attention_dim, nhead, feedforward_dim, dropout)
                for _ in range(num_decoder_layers)
            ]
        )

        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(
        self,
        memory: torch.Tensor,
        memory_lens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Forward decoder.

        Args:
          memory: encoded memory, (batch, maxlen_in, feat)
          memory_lens: (batch,)
          ys_in_pad: input token ids, (batch, maxlen_out)
          ys_in_lens: (batch,)

        Returns:
          tgt: decoded token scores before softmax, (batch, maxlen_out, vocab_size)
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask | (~m)

        memory_mask = make_pad_mask(memory_lens)[:, None, :].to(memory.device)

        tgt = self.embed(tgt)
        tgt = self.pos(tgt)

        for i, mod in enumerate(self.layers):
            tgt = mod(tgt, tgt_mask, memory, memory_mask)

        tgt = self.output_layer(tgt)
        return tgt


class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
      d_model: equal to encoder_dim
      attention_dim: total dimension of multi head attention
      nhead: number of attention heads
      feedforward_dim: hidden dimension of feed_forward module
      dropout: dropout rate
    """

    def __init__(
        self,
        d_model: int = 512,
        attention_dim: int = 512,
        nhead: int = 8,
        feedforward_dim: int = 2048,
        dropout: float = 0.1,
    ):
        """Construct a DecoderLayer object."""
        super(DecoderLayer, self).__init__()

        self.norm_self_attn = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)

        self.norm_src_attn = nn.LayerNorm(d_model)
        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)

        self.norm_ff = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, feedforward_dim),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(feedforward_dim, d_model),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt, tgt_mask, memory, memory_mask):
        """Compute decoded features.

        Args:
          tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
          tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
          memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
          memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).

        Returns:
          torch.Tensor: Output tensor (#batch, maxlen_out, size).
        """
        # self-attn module
        tgt_norm = self.norm_self_attn(tgt)
        tgt = tgt + self.dropout(self.self_attn(tgt_norm, tgt_norm, tgt_norm, tgt_mask))

        # cross-attn module
        tgt = tgt + self.dropout(
            self.src_attn(self.norm_src_attn(tgt), memory, memory, memory_mask)
        )

        # feed-forward module
        tgt = tgt + self.dropout(self.feed_forward(self.norm_ff(tgt)))

        return tgt


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
      embed_dim: total dimension of the model.
      attention_dim: dimension in the attention module; it may be smaller or larger
        than embed_dim, but it must be a multiple of num_heads.
      num_heads: parallel attention heads.
      dropout: a Dropout layer on attn_output_weights. Default: 0.0.
    """

    def __init__(
        self, embed_dim: int, attention_dim: int, num_heads: int, dropout: float = 0.0
    ):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        self.embed_dim = embed_dim
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = attention_dim // num_heads
        assert self.head_dim * num_heads == attention_dim, (
            self.head_dim,
            num_heads,
            attention_dim,
        )

        self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
        self.linear_k = nn.Linear(embed_dim, attention_dim, bias=True)
        self.linear_v = nn.Linear(embed_dim, attention_dim, bias=True)
        self.scale = math.sqrt(self.head_dim)

        self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
          query (torch.Tensor): Query tensor (#batch, time1, size).
          key (torch.Tensor): Key tensor (#batch, time2, size).
          value (torch.Tensor): Value tensor (#batch, time2, size).
          mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
            (#batch, time1, time2).

        Returns:
          torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        bsz, tgt_len, _ = query.size()
        src_len = key.size(1)
        num_heads = self.num_heads
        head_dim = self.head_dim

        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)

        q = q.reshape(bsz, tgt_len, num_heads, head_dim)
        q = q.transpose(1, 2)  # (batch, head, time1, head_dim)
        k = k.reshape(bsz, src_len, num_heads, head_dim)
        k = k.permute(0, 2, 3, 1)  # (batch, head, head_dim, time2)
        v = v.reshape(bsz, src_len, num_heads, head_dim)
        v = v.transpose(1, 2).reshape(bsz * num_heads, src_len, head_dim)

        # (batch, head, time1, time2)
        attn_output_weights = torch.matmul(q, k) / self.scale

        if mask is not None:
            attn_output_weights = attn_output_weights.masked_fill(
                mask.unsqueeze(1), float("-inf")
            )

        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

        attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
        attn_output_weights = nn.functional.dropout(
            attn_output_weights, p=self.dropout, training=self.training
        )

        # (bsz * head, time1, head_dim_v)
        attn_output = torch.bmm(attn_output_weights, v)
        assert attn_output.shape == (bsz * num_heads, tgt_len, head_dim)
        attn_output = (
            attn_output.reshape(bsz, num_heads, tgt_len, head_dim)
            .transpose(1, 2)
            .reshape(bsz, tgt_len, self.attention_dim)
        )
        attn_output = self.out_proj(attn_output)

        return attn_output


class PositionalEncoding(nn.Module):
    """Positional encoding.
    Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35.

    Args:
      d_model (int): Embedding dimension.
      dropout_rate (float): Dropout rate.
      max_len (int): Maximum input length.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct a PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Reset the positional encodings."""
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        Args:
          x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
          torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        x = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(x)


class Swish(torch.nn.Module):
    """Swish (SiLU) activation module."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Swish activation: x * sigmoid(x)."""
        return x * torch.sigmoid(x)


def subsequent_mask(size, device="cpu", dtype=torch.bool):
|
||||||
|
"""Create mask for subsequent steps (size, size).
|
||||||
|
|
||||||
|
:param int size: size of mask
|
||||||
|
:param str device: "cpu" or "cuda" or torch.Tensor.device
|
||||||
|
:param torch.dtype dtype: result dtype
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
>>> subsequent_mask(3)
|
||||||
|
[[1, 0, 0],
|
||||||
|
[1, 1, 0],
|
||||||
|
[1, 1, 1]]
|
||||||
|
"""
|
||||||
|
ret = torch.ones(size, size, device=device, dtype=dtype)
|
||||||
|
return torch.tril(ret, out=ret)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_attention_decoder_model():
    m = AttentionDecoderModel(
        vocab_size=500,
        decoder_dim=512,
        num_decoder_layers=6,
        attention_dim=512,
        nhead=8,
        feedforward_dim=2048,
        dropout=0.1,
        sos_id=1,
        eos_id=1,
        ignore_id=-1,
    )

    num_param = sum([p.numel() for p in m.parameters()])
    print(f"Number of model parameters: {num_param}")

    m.eval()
    encoder_out = torch.randn(2, 50, 512)
    encoder_out_lens = torch.full((2,), 50)
    token_ids = [[1, 2, 3, 4], [2, 3, 10]]

    nll = m.nll(encoder_out, encoder_out_lens, token_ids)
    print(nll)


if __name__ == "__main__":
    _test_attention_decoder_model()
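The _test function above exercises nll() with random inputs. To show how the new module can be used outside of training, for example to rescore a small n-best list with nll(), here is a minimal sketch. It is not part of the commit: the hypotheses, lengths and random encoder output are made up for illustration, and it assumes k2 and icefall are installed and that it runs from the recipe directory so that attention_decoder can be imported.

import torch

from attention_decoder import AttentionDecoderModel

# Illustrative only: a randomly initialised decoder and a fake encoder output
# repeated once per hypothesis.
model = AttentionDecoderModel(vocab_size=500, decoder_dim=512, sos_id=1, eos_id=1)
model.eval()

encoder_out = torch.randn(1, 50, 512).repeat(3, 1, 1)
encoder_out_lens = torch.full((3,), 50)
hyps = [[7, 23, 44], [7, 23, 45], [7, 22]]  # made-up n-best token ids

with torch.no_grad():
    # nll has shape (num_hyps, max_num_tokens + 1); padded positions contribute 0.
    nll = model.nll(encoder_out, encoder_out_lens, hyps)
scores = -nll.sum(dim=-1)
print("best hypothesis:", hyps[int(scores.argmax())])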
egs/librispeech/ASR/zipformer/label_smoothing.py (new file, 109 lines)
@@ -0,0 +1,109 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class LabelSmoothingLoss(torch.nn.Module):
    """
    Implement the LabelSmoothingLoss proposed in the following paper
    https://arxiv.org/pdf/1512.00567.pdf
    (Rethinking the Inception Architecture for Computer Vision)
    """

    def __init__(
        self,
        ignore_index: int = -1,
        label_smoothing: float = 0.1,
        reduction: str = "sum",
    ) -> None:
        """
        Args:
          ignore_index:
            ignored class id
          label_smoothing:
            smoothing rate (0.0 means the conventional cross entropy loss)
          reduction:
            It has the same meaning as the reduction in
            `torch.nn.CrossEntropyLoss`. It can be one of the following three
            values: (1) "none": No reduction will be applied. (2) "mean": the
            mean of the output is taken. (3) "sum": the output will be summed.
        """
        super().__init__()
        assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
        assert reduction in ("none", "sum", "mean"), reduction
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute loss between x and target.

        Args:
          x:
            prediction of dimension
            (batch_size, input_length, number_of_classes).
          target:
            target masked with self.ignore_index of
            dimension (batch_size, input_length).

        Returns:
          A scalar tensor containing the loss without normalization.
        """
        assert x.ndim == 3
        assert target.ndim == 2
        assert x.shape[:2] == target.shape
        num_classes = x.size(-1)
        x = x.reshape(-1, num_classes)
        # Now x is of shape (N*T, C)

        # We don't want to change target in-place below,
        # so we make a copy of it here
        target = target.clone().reshape(-1)

        ignored = target == self.ignore_index

        # See https://github.com/k2-fsa/icefall/issues/240
        # and https://github.com/k2-fsa/icefall/issues/297
        # for why we don't use target[ignored] = 0 here
        target = torch.where(ignored, torch.zeros_like(target), target)

        true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x)

        true_dist = (
            true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes
        )

        # Set the value of ignored indexes to 0
        #
        # See https://github.com/k2-fsa/icefall/issues/240
        # and https://github.com/k2-fsa/icefall/issues/297
        # for why we don't use true_dist[ignored] = 0 here
        true_dist = torch.where(
            ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
            torch.zeros_like(true_dist),
            true_dist,
        )

        loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
        if self.reduction == "sum":
            return loss.sum()
        elif self.reduction == "mean":
            return loss.sum() / (~ignored).sum()
        else:
            return loss.sum(dim=-1)
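To illustrate what the smoothing does (again, not part of the commit), the sketch below compares the smoothed loss with plain cross entropy on a toy batch; the tensor shapes follow the docstring above, and the values are arbitrary.

import torch

from label_smoothing import LabelSmoothingLoss

torch.manual_seed(0)
logits = torch.randn(2, 5, 10)        # (batch_size, input_length, num_classes)
target = torch.randint(0, 10, (2, 5))
target[1, 3:] = -1                    # padded positions carry ignore_index=-1

smoothed = LabelSmoothingLoss(ignore_index=-1, label_smoothing=0.1, reduction="sum")
plain = LabelSmoothingLoss(ignore_index=-1, label_smoothing=0.0, reduction="sum")

# With label_smoothing=0.0 this reduces to ordinary cross entropy over the
# non-ignored positions; with 0.1 a little probability mass is spread over
# all classes, which regularises the attention decoder.
print(smoothed(logits, target).item(), plain(logits, target).item())

The hunks that follow wire the attention decoder into the AsrModel class and the training script.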
@@ -34,11 +34,13 @@ class AsrModel(nn.Module):
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
+        attention_decoder: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
         vocab_size: int = 500,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        use_attention_decoder: bool = False,
     ):
         """A joint CTC & Transducer ASR model.

@@ -111,6 +113,12 @@ class AsrModel(nn.Module):
                 nn.LogSoftmax(dim=-1),
             )

+        self.use_attention_decoder = use_attention_decoder
+        if use_attention_decoder:
+            self.attention_decoder = attention_decoder
+        else:
+            assert attention_decoder is None
+
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -286,7 +294,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
           x:
@@ -308,7 +316,7 @@ class AsrModel(nn.Module):
             part
         Returns:
           Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss)
+          in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)

         Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -322,6 +330,8 @@ class AsrModel(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)

+        device = x.device
+
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

@@ -333,7 +343,7 @@ class AsrModel(nn.Module):
             simple_loss, pruned_loss = self.forward_transducer(
                 encoder_out=encoder_out,
                 encoder_out_lens=encoder_out_lens,
-                y=y.to(x.device),
+                y=y.to(device),
                 y_lens=y_lens,
                 prune_range=prune_range,
                 am_scale=am_scale,
@@ -355,4 +365,14 @@ class AsrModel(nn.Module):
         else:
             ctc_loss = torch.empty(0)

-        return simple_loss, pruned_loss, ctc_loss
+        if self.use_attention_decoder:
+            attention_decoder_loss = self.attention_decoder.calc_att_loss(
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                ys=y.to(device),
+                ys_lens=y_lens.to(device),
+            )
+        else:
+            attention_decoder_loss = torch.empty(0)
+
+        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
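The new branch in forward() above simply delegates to AttentionDecoderModel.calc_att_loss with the encoder output and the ragged label tensor. A stand-alone sketch of that call, with made-up shapes and token ids (not part of the commit; it needs k2 and icefall, and assumes it runs from the recipe directory):

import k2
import torch

from attention_decoder import AttentionDecoderModel

decoder = AttentionDecoderModel(vocab_size=500, decoder_dim=512, sos_id=1, eos_id=1)

encoder_out = torch.randn(2, 100, 512)      # (batch, num_frames, encoder_dim)
encoder_out_lens = torch.tensor([100, 80])
y = k2.RaggedTensor([[12, 7, 33], [5, 9]])  # toy label sequences
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]

# Summed label-smoothed cross entropy over all tokens in the batch; in the
# training script below it is scaled by --attention-decoder-loss-scale and
# added to the other losses.
att_loss = decoder.calc_att_loss(encoder_out, encoder_out_lens, y, y_lens)
print(att_loss)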
@@ -66,6 +66,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -220,6 +221,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         """,
     )

+    parser.add_argument(
+        "--attention-decoder-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-num-layers",
+        type=int,
+        default=6,
+        help="""Number of transformer layers used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-attention-dim",
+        type=int,
+        default=512,
+        help="""Attention dimension used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-num-heads",
+        type=int,
+        default=8,
+        help="""Number of attention heads used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-feedforward-dim",
+        type=int,
+        default=2048,
+        help="""Feedforward dimension used in attention decoder""",
+    )
+
     parser.add_argument(
         "--causal",
         type=str2bool,
@@ -258,6 +294,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use CTC head.",
     )

+    parser.add_argument(
+        "--use-attention-decoder",
+        type=str2bool,
+        default=False,
+        help="If True, use attention-decoder head.",
+    )
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -403,6 +446,21 @@ def get_parser():
         help="Scale for CTC loss.",
     )

+    parser.add_argument(
+        "--attention-decoder-loss-scale",
+        type=float,
+        default=0.8,
+        help="Scale for attention-decoder loss.",
+    )
+
+    parser.add_argument(
+        "--label-smoothing",
+        type=float,
+        default=0.1,
+        help="""Label smoothing rate used in attention decoder,
+        (0.0 means the conventional cross entropy loss)""",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -531,6 +589,8 @@ def get_params() -> AttributeDict:
             # parameters for zipformer
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
+            # parameters for attention-decoder
+            "ignore_id": -1,
             "warm_step": 2000,
             "env_info": get_env_info(),
         }
@@ -603,6 +663,25 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     return joiner


+def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
+    encoder_dim = max(_to_int_tuple(params.encoder_dim))
+    assert params.attention_decoder_dim == encoder_dim, (params.attention_decoder_dim, encoder_dim)
+
+    decoder = AttentionDecoderModel(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.attention_decoder_dim,
+        num_decoder_layers=params.attention_decoder_num_layers,
+        attention_dim=params.attention_decoder_attention_dim,
+        nhead=params.attention_decoder_num_heads,
+        feedforward_dim=params.attention_decoder_feedforward_dim,
+        sos_id=params.sos_id,
+        eos_id=params.eos_id,
+        ignore_id=params.ignore_id,
+        label_smoothing=params.label_smoothing,
+    )
+    return decoder
+
+
 def get_model(params: AttributeDict) -> nn.Module:
     assert params.use_transducer or params.use_ctc, (
         f"At least one of them should be True, "
@@ -620,16 +699,23 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None

+    if params.use_attention_decoder:
+        attention_decoder = get_attention_decoder_model(params)
+    else:
+        attention_decoder = None
+
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        attention_decoder=attention_decoder,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        use_attention_decoder=params.use_attention_decoder,
     )
     return model

@@ -792,7 +878,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -822,6 +908,9 @@ def compute_loss(
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss

+        if params.use_attention_decoder:
+            loss += params.attention_decoder_loss_scale * attention_decoder_loss
+
     assert loss.requires_grad == is_training

     info = MetricsTracker()
@@ -836,6 +925,8 @@ def compute_loss(
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
     if params.use_ctc:
         info["ctc_loss"] = ctc_loss.detach().cpu().item()
+    if params.use_attention_decoder:
+        info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()

     return loss, info

@@ -1114,10 +1205,17 @@ def run(rank, world_size, args):

     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.eos_id = sp.piece_to_id("<sos/eos>")
+    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     if not params.use_transducer:
-        params.ctc_loss_scale = 1.0
+        if not params.use_attention_decoder:
+            params.ctc_loss_scale = 1.0
+        else:
+            assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
+                params.ctc_loss_scale, params.attention_decoder_loss_scale
+            )

     logging.info(params)

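Read in isolation, the check added to run() above enforces that, when the transducer head is disabled, the CTC and attention-decoder scales sum to one, since those two terms are then the whole objective. A minimal sketch with made-up scale values (the 0.2/0.8 split is only an example, not a recipe default):

use_transducer = False
use_attention_decoder = True
ctc_loss_scale = 0.2                  # illustrative value
attention_decoder_loss_scale = 0.8    # default added by this commit

if not use_transducer:
    if not use_attention_decoder:
        # CTC is the only remaining loss, so it gets the full weight.
        ctc_loss_scale = 1.0
    else:
        assert ctc_loss_scale + attention_decoder_loss_scale == 1.0, (
            ctc_loss_scale,
            attention_decoder_loss_scale,
        )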