mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Merge 3364d9863c7816da59b9df30f58784617ab23b84 into ff2bef9e501a4b5ebfec04cbfe8afa2e8bea4b40
This commit is contained in:
commit
ad4615eba3
0
egs/librispeech/ASR/zipformer_ctc_attn/__init__.py
Normal file
0
egs/librispeech/ASR/zipformer_ctc_attn/__init__.py
Normal file
1
egs/librispeech/ASR/zipformer_ctc_attn/asr_datamodule.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc_attn/asr_datamodule.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/asr_datamodule.py
|
746
egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
Normal file
746
egs/librispeech/ASR/zipformer_ctc_attn/attention_decoder.py
Normal file
@ -0,0 +1,746 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# The model structure is modified from Daniel Povey's Zipformer
|
||||
# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import List, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
BasicNorm,
|
||||
DoubleSwish,
|
||||
Identity,
|
||||
MaxEig,
|
||||
ScaledConv1d,
|
||||
ScaledLinear,
|
||||
Whiten,
|
||||
_diag,
|
||||
penalize_abs_values_gt,
|
||||
random_clamp,
|
||||
softmax,
|
||||
)
|
||||
from zipformer import FeedforwardModule
|
||||
|
||||
from icefall.utils import add_eos, add_sos, make_pad_mask
|
||||
|
||||
|
||||
class AttentionDecoderModel(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
vocab_size (int): Number of classes.
|
||||
decoder_dim: (int,int): embedding dimension of 2 encoder stacks
|
||||
attention_dim: (int,int): attention dimension of 2 encoder stacks
|
||||
nhead (int, int): number of heads
|
||||
dim_feedforward (int, int): feedforward dimension in 2 encoder stacks
|
||||
num_encoder_layers (int): number of encoder layers
|
||||
dropout (float): dropout rate
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
warmup_batches (float): number of batches to warm up over
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
unmasked_dim: int,
|
||||
num_decoder_layers: int,
|
||||
attention_dim: int,
|
||||
nhead: int,
|
||||
feedforward_dim: int,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
dropout: float = 0.1,
|
||||
ignore_id: int = -1,
|
||||
warmup_batches: float = 4000.0,
|
||||
label_smoothing: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self.eos_id = eos_id
|
||||
self.sos_id = sos_id
|
||||
self.ignore_id = ignore_id
|
||||
|
||||
# For the segment of the warmup period, we let the Embedding
|
||||
# layer learn something. Then we start to warm up the other encoders.
|
||||
self.decoder = TransformerDecoder(
|
||||
vocab_size,
|
||||
decoder_dim,
|
||||
unmasked_dim,
|
||||
num_decoder_layers,
|
||||
attention_dim,
|
||||
nhead,
|
||||
feedforward_dim,
|
||||
dropout,
|
||||
warmup_begin=warmup_batches * 0.5,
|
||||
warmup_end=warmup_batches * 1.0,
|
||||
)
|
||||
|
||||
# Used to calculate attention-decoder loss
|
||||
self.loss_fun = LabelSmoothingLoss(
|
||||
ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum"
|
||||
)
|
||||
|
||||
def _pre_ys_in_out(self, token_ids: List[List[int]], device: torch.device):
|
||||
"""Prepare ys_in_pad and ys_out_pad."""
|
||||
ys = k2.RaggedTensor(token_ids).to(device=device)
|
||||
row_splits = ys.shape.row_splits(1)
|
||||
ys_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
ys_in = add_sos(ys, sos_id=self.sos_id)
|
||||
# [B, S+1], start with SOS
|
||||
ys_in_pad = ys_in.pad(mode="constant", padding_value=self.eos_id)
|
||||
ys_in_lens = ys_lens + 1
|
||||
|
||||
ys_out = add_eos(ys, eos_id=self.eos_id)
|
||||
# [B, S+1], end with EOS
|
||||
ys_out_pad = ys_out.pad(mode="constant", padding_value=self.ignore_id)
|
||||
|
||||
return ys_in_pad.to(torch.int64), ys_in_lens, ys_out_pad.to(torch.int64)
|
||||
|
||||
def calc_att_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
token_ids: List[List[int]],
|
||||
) -> torch.Tensor:
|
||||
"""Calculate attention-decoder loss.
|
||||
Args:
|
||||
encoder_out: (batch, num_frames, encoder_dim)
|
||||
encoder_out_lens: (batch,)
|
||||
token_ids: A list of token id list.
|
||||
|
||||
Return: The attention-decoder loss.
|
||||
"""
|
||||
ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(
|
||||
token_ids, encoder_out.device
|
||||
)
|
||||
|
||||
# decoder forward
|
||||
decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
|
||||
|
||||
loss = self.loss_fun(x=decoder_out, target=ys_out_pad)
|
||||
return loss
|
||||
|
||||
def nll(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
token_ids: List[List[int]],
|
||||
) -> torch.Tensor:
|
||||
"""Compute negative log likelihood(nll) from attention-decoder.
|
||||
Args:
|
||||
encoder_out: (batch, num_frames, encoder_dim)
|
||||
encoder_out_lens: (batch,)
|
||||
token_ids: A list of token id list.
|
||||
|
||||
Return: A tensor of shape (batch, num_tokens).
|
||||
"""
|
||||
ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(
|
||||
token_ids, encoder_out.device
|
||||
)
|
||||
|
||||
# decoder forward
|
||||
decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
|
||||
|
||||
batch_size, _, num_classes = decoder_out.size()
|
||||
nll = nn.functional.cross_entropy(
|
||||
decoder_out.view(-1, num_classes),
|
||||
ys_out_pad.view(-1),
|
||||
ignore_index=self.ignore_id,
|
||||
reduction="none",
|
||||
)
|
||||
nll = nll.view(batch_size, -1)
|
||||
return nll
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
"""Transfomer decoder module.
|
||||
It is modified from https://github.com/espnet/espnet/blob/master/espnet2/asr/decoder/transformer_decoder.py.
|
||||
|
||||
Args:
|
||||
vocab_size: output dim
|
||||
d_model: decoder dimension
|
||||
num_decoder_layers: number of decoder layers
|
||||
attention_dim: total dimension of multi head attention
|
||||
n_head: number of attention heads
|
||||
feedforward_dim: hidden dimension of feed_forward module
|
||||
dropout: dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
d_model: int,
|
||||
unmasked_dim: int,
|
||||
num_decoder_layers: int,
|
||||
attention_dim: int,
|
||||
nhead: int,
|
||||
feedforward_dim: int,
|
||||
dropout: float,
|
||||
warmup_begin: float,
|
||||
warmup_end: float,
|
||||
):
|
||||
super().__init__()
|
||||
self.unmasked_dim = unmasked_dim
|
||||
|
||||
self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
|
||||
|
||||
# Using absolute positional encoding
|
||||
self.pos = PositionalEncoding(d_model, dropout_rate=0.1)
|
||||
|
||||
self.num_layers = num_decoder_layers
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DecoderLayer(d_model, attention_dim, nhead, feedforward_dim, dropout)
|
||||
for _ in range(num_decoder_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.output_layer = nn.Linear(d_model, vocab_size)
|
||||
|
||||
# will be written to, see set_batch_count() Note: in inference time this
|
||||
# may be zero but should be treated as large, we can check if
|
||||
# self.training is true.
|
||||
self.batch_count = 0
|
||||
assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end)
|
||||
self.warmup_begin = warmup_begin
|
||||
self.warmup_end = warmup_end
|
||||
# module_seed is for when we need a random number that is unique to the module but
|
||||
# shared across jobs. It's used to randomly select how many layers to drop,
|
||||
# so that we can keep this consistent across worker tasks (for efficiency).
|
||||
self.module_seed = torch.randint(0, 1000, ()).item()
|
||||
|
||||
delta = (1.0 / num_decoder_layers) * (warmup_end - warmup_begin)
|
||||
cur_begin = warmup_begin
|
||||
for i in range(num_decoder_layers):
|
||||
self.layers[i].warmup_begin = cur_begin
|
||||
cur_begin += delta
|
||||
self.layers[i].warmup_end = cur_begin
|
||||
|
||||
def get_layers_to_drop(self, rnd_seed: int):
|
||||
ans = set()
|
||||
if not self.training:
|
||||
return ans
|
||||
|
||||
batch_count = self.batch_count
|
||||
num_layers = len(self.layers)
|
||||
|
||||
def get_layerdrop_prob(layer: int) -> float:
|
||||
layer_warmup_begin = self.layers[layer].warmup_begin
|
||||
layer_warmup_end = self.layers[layer].warmup_end
|
||||
|
||||
initial_layerdrop_prob = 0.5
|
||||
final_layerdrop_prob = 0.05
|
||||
|
||||
if batch_count == 0:
|
||||
# As a special case, if batch_count == 0, return 0 (drop no
|
||||
# layers). This is rather ugly, I'm afraid; it is intended to
|
||||
# enable our scan_pessimistic_batches_for_oom() code to work correctly
|
||||
# so if we are going to get OOM it will happen early.
|
||||
# also search for 'batch_count' with quotes in this file to see
|
||||
# how we initialize the warmup count to a random number between
|
||||
# 0 and 10.
|
||||
return 0.0
|
||||
elif batch_count < layer_warmup_begin:
|
||||
return initial_layerdrop_prob
|
||||
elif batch_count > layer_warmup_end:
|
||||
return final_layerdrop_prob
|
||||
else:
|
||||
# linearly interpolate
|
||||
t = (batch_count - layer_warmup_begin) / layer_warmup_end
|
||||
assert 0.0 <= t < 1.001, t
|
||||
return initial_layerdrop_prob + t * (
|
||||
final_layerdrop_prob - initial_layerdrop_prob
|
||||
)
|
||||
|
||||
shared_rng = random.Random(batch_count + self.module_seed)
|
||||
independent_rng = random.Random(rnd_seed)
|
||||
|
||||
layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)]
|
||||
tot = sum(layerdrop_probs)
|
||||
# Instead of drawing the samples independently, we first randomly decide
|
||||
# how many layers to drop out, using the same random number generator between
|
||||
# jobs so that all jobs drop out the same number (this is for speed).
|
||||
# Then we use an approximate approach to drop out the individual layers
|
||||
# with their specified probs while reaching this exact target.
|
||||
num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))
|
||||
|
||||
layers = list(range(num_layers))
|
||||
independent_rng.shuffle(layers)
|
||||
|
||||
# go through the shuffled layers until we get the required number of samples.
|
||||
if num_to_drop > 0:
|
||||
for layer in itertools.cycle(layers):
|
||||
if independent_rng.random() < layerdrop_probs[layer]:
|
||||
ans.add(layer)
|
||||
if len(ans) == num_to_drop:
|
||||
break
|
||||
if shared_rng.random() < 0.005 or __name__ == "__main__":
|
||||
logging.info(
|
||||
f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
|
||||
f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}"
|
||||
)
|
||||
return ans
|
||||
|
||||
def get_feature_mask(self, x: torch.Tensor) -> float:
|
||||
# Note: The actual return type is Union[List[float], List[Tensor]],
|
||||
# but to make torch.jit.script() work, we use List[float]
|
||||
"""
|
||||
In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
|
||||
randomized feature masks, one per encoder.
|
||||
On e.g. 15% of frames, these masks will zero out all enocder dims larger than
|
||||
some supplied number, e.g. >256, so in effect on those frames we are using
|
||||
a smaller encoer dim.
|
||||
|
||||
We generate the random masks at this level because we want the 2 masks to 'agree'
|
||||
all the way up the encoder stack. This will mean that the 1st mask will have
|
||||
mask values repeated self.zipformer_subsampling_factor times.
|
||||
|
||||
Args:
|
||||
x: the embeddings (needed for the shape and dtype and device), of shape
|
||||
(num_frames, batch_size, encoder_dims0)
|
||||
"""
|
||||
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
|
||||
return 1.0
|
||||
|
||||
batch_size, num_frames, d_model = x.size()
|
||||
|
||||
feature_mask_dropout_prob = 0.15
|
||||
frame_mask = (
|
||||
torch.rand(batch_size, num_frames, 1, device=x.device)
|
||||
> feature_mask_dropout_prob
|
||||
).to(x.dtype)
|
||||
|
||||
feature_mask = torch.ones(
|
||||
batch_size, num_frames, d_model, dtype=x.dtype, device=x.device
|
||||
)
|
||||
feature_mask[:, :, self.unmasked_dim :] *= frame_mask
|
||||
|
||||
return feature_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hs_pad: torch.Tensor,
|
||||
hlens: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
|
||||
Args:
|
||||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
hlens: (batch)
|
||||
ys_in_pad:
|
||||
input token ids, int64 (batch, maxlen_out)
|
||||
if input_layer == "embed"
|
||||
input tensor (batch, maxlen_out, #mels) in the other cases
|
||||
ys_in_lens: (batch)
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
|
||||
x: decoded token score before softmax (batch, maxlen_out, token)
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
tgt = ys_in_pad
|
||||
# tgt_mask: (B, 1, L)
|
||||
tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
|
||||
# m: (1, L, L)
|
||||
m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
|
||||
# tgt_mask: (B, L, L)
|
||||
tgt_mask = tgt_mask | (~m)
|
||||
|
||||
memory = hs_pad
|
||||
memory_mask = make_pad_mask(hlens)[:, None, :].to(memory.device)
|
||||
|
||||
tgt = self.embed(tgt)
|
||||
tgt = self.pos(tgt)
|
||||
|
||||
rnd_seed = tgt.numel() + random.randint(0, 1000)
|
||||
layers_to_drop = self.get_layers_to_drop(rnd_seed)
|
||||
|
||||
feature_mask = self.get_feature_mask(tgt)
|
||||
|
||||
for i, mod in enumerate(self.layers):
|
||||
if i in layers_to_drop:
|
||||
continue
|
||||
tgt = mod(tgt, tgt_mask, memory, memory_mask)
|
||||
tgt = tgt * feature_mask
|
||||
|
||||
tgt = self.output_layer(tgt)
|
||||
return tgt
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
"""Single decoder layer module.
|
||||
|
||||
Args:
|
||||
d_model: equal to encoder_dim
|
||||
attention_dim: total dimension of multi head attention
|
||||
n_head: number of attention heads
|
||||
feedforward_dim: hidden dimension of feed_forward module
|
||||
dropout: dropout rate
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
attention_dim: int,
|
||||
nhead: int,
|
||||
feedforward_dim: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
):
|
||||
"""Construct an DecoderLayer object."""
|
||||
super(DecoderLayer, self).__init__()
|
||||
|
||||
# will be written to, see set_batch_count()
|
||||
self.batch_count = 0
|
||||
|
||||
self.self_attn = MultiHeadedAttention(
|
||||
d_model, attention_dim, nhead, dropout=0.0
|
||||
)
|
||||
self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
|
||||
self.feed_forward = FeedforwardModule(d_model, feedforward_dim, dropout)
|
||||
|
||||
self.norm_final = BasicNorm(d_model)
|
||||
|
||||
self.bypass_scale = nn.Parameter(torch.tensor(0.5))
|
||||
|
||||
# try to ensure the output is close to zero-mean (or at least, zero-median).
|
||||
self.balancer = ActivationBalancer(
|
||||
d_model,
|
||||
channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55,
|
||||
max_abs=6.0,
|
||||
)
|
||||
self.whiten = Whiten(
|
||||
num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01
|
||||
)
|
||||
|
||||
def get_bypass_scale(self):
|
||||
if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||
return self.bypass_scale
|
||||
if random.random() < 0.1:
|
||||
# ensure we get grads if self.bypass_scale becomes out of range
|
||||
return self.bypass_scale
|
||||
# hardcode warmup period for bypass scale
|
||||
warmup_period = 20000.0
|
||||
initial_clamp_min = 0.75
|
||||
final_clamp_min = 0.25
|
||||
if self.batch_count > warmup_period:
|
||||
clamp_min = final_clamp_min
|
||||
else:
|
||||
clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * (
|
||||
initial_clamp_min - final_clamp_min
|
||||
)
|
||||
return self.bypass_scale.clamp(min=clamp_min, max=1.0)
|
||||
|
||||
def get_dynamic_dropout_rate(self):
|
||||
# return dropout rate for the dynamic modules (self_attn, src_attn, feed_forward); this
|
||||
# starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable
|
||||
# at the beginning, by making the network focus on the feedforward modules.
|
||||
if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
|
||||
return 0.0
|
||||
warmup_period = 2000.0
|
||||
initial_dropout_rate = 0.2
|
||||
final_dropout_rate = 0.0
|
||||
if self.batch_count > warmup_period:
|
||||
return final_dropout_rate
|
||||
else:
|
||||
return initial_dropout_rate - (
|
||||
initial_dropout_rate * final_dropout_rate
|
||||
) * (self.batch_count / warmup_period)
|
||||
|
||||
def forward(self, tgt, tgt_mask, memory, memory_mask):
|
||||
"""Compute decoded features.
|
||||
|
||||
Args:
|
||||
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
|
||||
tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
|
||||
memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
|
||||
memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor(#batch, maxlen_out, size).
|
||||
"""
|
||||
tgt_orig = tgt
|
||||
|
||||
# dropout rate for submodules that interact with time.
|
||||
dynamic_dropout = self.get_dynamic_dropout_rate()
|
||||
|
||||
# self-attn module
|
||||
if random.random() >= dynamic_dropout:
|
||||
tgt = tgt + self.self_attn(tgt, tgt, tgt, tgt_mask)
|
||||
|
||||
# cross-attn module
|
||||
if random.random() >= dynamic_dropout:
|
||||
tgt = tgt + self.src_attn(tgt, memory, memory, memory_mask)
|
||||
|
||||
# feed-forward module
|
||||
tgt = tgt + self.feed_forward(tgt)
|
||||
|
||||
tgt = self.norm_final(self.balancer(tgt))
|
||||
|
||||
delta = tgt - tgt_orig
|
||||
tgt = tgt_orig + delta * self.get_bypass_scale()
|
||||
|
||||
return self.whiten(tgt)
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model.
|
||||
attention_dim: dimension in the attention module, may be less or more than embed_dim
|
||||
but must be a multiple of num_heads.
|
||||
num_heads: parallel attention heads.
|
||||
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, embed_dim: int, attention_dim: int, num_heads: int, dropout: float = 0.0
|
||||
):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super(MultiHeadedAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.attention_dim = attention_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = attention_dim // num_heads
|
||||
assert self.head_dim % 2 == 0, self.head_dim
|
||||
assert self.head_dim * num_heads == attention_dim, (
|
||||
self.head_dim,
|
||||
num_heads,
|
||||
attention_dim,
|
||||
)
|
||||
|
||||
# the initial_scale is supposed to take over the "scaling" factor of
|
||||
# head_dim ** -0.5, dividing it between the query and key.
|
||||
self.linear_q = ScaledLinear(
|
||||
embed_dim, attention_dim, bias=True, initial_scale=self.head_dim**-0.25
|
||||
)
|
||||
self.linear_k = ScaledLinear(
|
||||
embed_dim, attention_dim, bias=True, initial_scale=self.head_dim**-0.25
|
||||
)
|
||||
self.linear_v = ScaledLinear(
|
||||
embed_dim,
|
||||
attention_dim // 2,
|
||||
bias=True,
|
||||
initial_scale=self.head_dim**-0.25,
|
||||
)
|
||||
|
||||
# self.whiten_v is applied on the values in forward();
|
||||
# it just copies the keys but prevents low-rank distribution by modifying grads.
|
||||
self.whiten_v = Whiten(
|
||||
num_groups=num_heads,
|
||||
whitening_limit=2.0,
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.025,
|
||||
)
|
||||
self.whiten_k = Whiten(
|
||||
num_groups=num_heads,
|
||||
whitening_limit=2.0,
|
||||
prob=(0.025, 0.25),
|
||||
grad_scale=0.025,
|
||||
)
|
||||
|
||||
self.out_proj = ScaledLinear(
|
||||
attention_dim // 2, embed_dim, bias=True, initial_scale=0.05
|
||||
)
|
||||
|
||||
def forward(self, query, key, value, mask):
|
||||
"""Compute scaled dot product attention.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
|
||||
"""
|
||||
bsz, tgt_len, _ = query.size()
|
||||
src_len = key.size(1)
|
||||
num_heads = self.num_heads
|
||||
head_dim = self.head_dim
|
||||
|
||||
q = self.linear_q(query)
|
||||
k = self.linear_k(key)
|
||||
v = self.linear_v(value)
|
||||
|
||||
k = self.whiten_k(k) # does nothing in the forward pass.
|
||||
v = self.whiten_v(v) # does nothing in the forward pass.
|
||||
|
||||
q = q.reshape(bsz, tgt_len, num_heads, head_dim)
|
||||
q = q.transpose(1, 2) # (batch, head, time1, head_dim)
|
||||
k = k.reshape(bsz, src_len, num_heads, head_dim)
|
||||
k = k.permute(0, 2, 3, 1) # (batch, head, head_dim, time2)
|
||||
v = v.reshape(bsz, src_len, num_heads, head_dim // 2)
|
||||
v = v.transpose(1, 2).reshape(bsz * num_heads, src_len, head_dim // 2)
|
||||
|
||||
# (batch, head, time1, time2)
|
||||
attn_output_weights = torch.matmul(q, k)
|
||||
|
||||
# This is a harder way of limiting the attention scores to not be too large.
|
||||
# It incurs a penalty if any of them has an absolute value greater than 50.0.
|
||||
# this should be outside the normal range of the attention scores. We use
|
||||
# this mechanism instead of, say, a limit on entropy, because once the entropy
|
||||
# gets very small gradients through the softmax can become very small, and
|
||||
# some mechanisms like that become ineffective.
|
||||
attn_output_weights = penalize_abs_values_gt(
|
||||
attn_output_weights, limit=25.0, penalty=1.0e-04
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
mask.unsqueeze(1), float("-inf")
|
||||
)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len
|
||||
)
|
||||
|
||||
# Using this version of softmax, defined in scaling.py,
|
||||
# should save a little of the memory used in backprop by, if
|
||||
# we are in automatic mixed precision mode (amp) == autocast,
|
||||
# only storing the half-precision output for backprop purposes.
|
||||
attn_output_weights = softmax(attn_output_weights, dim=-1)
|
||||
attn_output_weights = nn.functional.dropout(
|
||||
attn_output_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
# (bsz * head, time1, head_dim_v)
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert attn_output.shape == (bsz * num_heads, tgt_len, head_dim // 2)
|
||||
attn_output = (
|
||||
attn_output.reshape(bsz, num_heads, tgt_len, head_dim // 2)
|
||||
.transpose(1, 2)
|
||||
.reshape(bsz, tgt_len, self.attention_dim // 2)
|
||||
)
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
"""Positional encoding.
|
||||
Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35.
|
||||
|
||||
Args:
|
||||
d_model (int): Embedding dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
max_len (int): Maximum input length.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x):
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
if self.pe.size(1) >= x.size(1):
|
||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
pe = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale + self.pe[:, : x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
def subsequent_mask(size, device="cpu", dtype=torch.bool):
|
||||
"""Create mask for subsequent steps (size, size).
|
||||
|
||||
:param int size: size of mask
|
||||
:param str device: "cpu" or "cuda" or torch.Tensor.device
|
||||
:param torch.dtype dtype: result dtype
|
||||
:rtype: torch.Tensor
|
||||
>>> subsequent_mask(3)
|
||||
[[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1]]
|
||||
"""
|
||||
ret = torch.ones(size, size, device=device, dtype=dtype)
|
||||
return torch.tril(ret, out=ret)
|
||||
|
||||
|
||||
def _test_attention_decoder_model():
|
||||
m = AttentionDecoderModel(
|
||||
vocab_size=500,
|
||||
decoder_dim=384,
|
||||
unmasked_dim=256,
|
||||
num_decoder_layers=6,
|
||||
attention_dim=192,
|
||||
nhead=8,
|
||||
feedforward_dim=2048,
|
||||
dropout=0.1,
|
||||
sos_id=1,
|
||||
eos_id=1,
|
||||
ignore_id=-1,
|
||||
)
|
||||
m.eval()
|
||||
encoder_out = torch.randn(2, 50, 384)
|
||||
encoder_out_lens = torch.full((2,), 50)
|
||||
token_ids = [[1, 2, 3, 4], [2, 3, 10]]
|
||||
loss = m.calc_att_loss(encoder_out, encoder_out_lens, token_ids)
|
||||
print(loss)
|
||||
|
||||
nll = m.nll(encoder_out, encoder_out_lens, token_ids)
|
||||
print(nll)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_attention_decoder_model()
|
841
egs/librispeech/ASR/zipformer_ctc_attn/decode.py
Executable file
841
egs/librispeech/ASR/zipformer_ctc_attn/decode.py
Executable file
@ -0,0 +1,841 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Liyong Guo,
|
||||
# Quandong Wang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) ctc-decoding
|
||||
./zipformer_ctc_attn/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer_ctc_attn/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method ctc-decoding
|
||||
|
||||
(2) 1best
|
||||
./zipformer_ctc_attn/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer_ctc_attn/exp \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.8 \
|
||||
--decoding-method 1best
|
||||
|
||||
(3) nbest
|
||||
./zipformer_ctc_attn/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer_ctc_attn/exp \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.8 \
|
||||
--decoding-method nbest
|
||||
|
||||
(4) nbest-rescoring
|
||||
./zipformer_ctc_attn/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer_ctc_attn/exp \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.8 \
|
||||
--lm-dir data/lm \
|
||||
--decoding-method nbest-rescoring
|
||||
|
||||
(5) whole-lattice-rescoring
|
||||
./zipformer_ctc_attn/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer_ctc_attn/exp \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.8 \
|
||||
--lm-dir data/lm \
|
||||
--decoding-method whole-lattice-rescoring
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from train import add_model_arguments, get_ctc_attention_model, get_params
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder2,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer_ctc_attn/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="ctc-decoding",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (2) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (3) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||
the highest score is the decoding result.
|
||||
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||
is the decoding result.
|
||||
you have trained an RNN LM using ./rnn_lm/train.py
|
||||
- (6) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Number of paths for n-best based decoding method.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, and nbest-oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""The scale to be applied to `lattice.scores`.
|
||||
It's needed if you use any kinds of n-best based rescoring.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, and nbest-oracle
|
||||
A smaller value results in more unique paths.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hlg-scale",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="""The scale to be applied to `hlg.scores`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-dir",
|
||||
type=str,
|
||||
default="data/lm",
|
||||
help="""The n-gram LM dir.
|
||||
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_decoding_params() -> AttributeDict:
|
||||
"""Parameters for decoding."""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"frame_shift_ms": 10,
|
||||
"search_beam": 20,
|
||||
"output_beam": 8,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
batch: dict,
|
||||
word_table: k2.SymbolTable,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if no rescoring is used, the key is the string `no_rescore`.
|
||||
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
||||
where `xxx` is the value of `lm_scale`. An example key is
|
||||
`lm_scale_0.7`
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
|
||||
- params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
|
||||
- params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
|
||||
- params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
|
||||
- params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
|
||||
rescoring.
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.decoding_method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.decoding_method is ctc-decoding.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
G:
|
||||
An LM. It is not None when params.decoding_method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict. Note: If it decodes to nothing, then return None.
|
||||
"""
|
||||
if HLG is not None:
|
||||
device = HLG.device
|
||||
else:
|
||||
device = H.device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
|
||||
nnet_output = model.ctc_output(encoder_out)
|
||||
# nnet_output is (N, T, C)
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"] // params.subsampling_factor,
|
||||
supervisions["num_frames"] // params.subsampling_factor,
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
if H is None:
|
||||
assert HLG is not None
|
||||
decoding_graph = HLG
|
||||
else:
|
||||
assert HLG is None
|
||||
assert bpe_model is not None
|
||||
decoding_graph = H
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=decoding_graph,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.decoding_method == "ctc-decoding":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
# Note: `best_path.aux_labels` contains token IDs, not word IDs
|
||||
# since we are using H, not HLG here.
|
||||
#
|
||||
# token_ids is a lit-of-list of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-decoding"
|
||||
return {key: hyps}
|
||||
|
||||
if params.decoding_method == "nbest-oracle":
|
||||
# Note: You can also pass rescored lattices to it.
|
||||
# We choose the HLG decoded lattice for speed reasons
|
||||
# as HLG decoding is faster and the oracle WER
|
||||
# is only slightly worse than that of rescored lattices.
|
||||
best_path = nbest_oracle(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=supervisions["text"],
|
||||
word_table=word_table,
|
||||
nbest_scale=params.nbest_scale,
|
||||
oov="<UNK>",
|
||||
)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
|
||||
return {key: hyps}
|
||||
|
||||
if params.decoding_method in ["1best", "nbest"]:
|
||||
if params.decoding_method == "1best":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
key = "no_rescore"
|
||||
else:
|
||||
best_path = nbest_decoding(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
use_double_scores=params.use_double_scores,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}
|
||||
|
||||
assert params.decoding_method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
]
|
||||
|
||||
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
||||
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
||||
|
||||
if params.decoding_method == "nbest-rescoring":
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=lm_scale_list,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
elif params.decoding_method == "whole-lattice-rescoring":
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
elif params.decoding_method == "attention-decoder":
|
||||
# lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=None,
|
||||
)
|
||||
best_path_dict = rescore_with_attention_decoder2(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
attention_decoder=model.decoder,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
||||
ans = dict()
|
||||
if best_path_dict is not None:
|
||||
for lm_scale_str, best_path in best_path_dict.items():
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
ans[lm_scale_str] = hyps
|
||||
else:
|
||||
ans = None
|
||||
return ans
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
word_table: k2.SymbolTable,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.decoding_method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.decoding_method is ctc-decoding.
|
||||
word_table:
|
||||
It is the word symbol table.
|
||||
G:
|
||||
An LM. It is not None when params.decoding_method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return a dict, whose key may be "no-rescore" if no LM rescoring
|
||||
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
batch=batch,
|
||||
word_table=word_table,
|
||||
G=G,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
args.lm_dir = Path(args.lm_dir)
|
||||
|
||||
params = get_params()
|
||||
# add decoding params
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"ctc-decoding",
|
||||
"1best",
|
||||
"nbest",
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"nbest-oracle",
|
||||
"attention-decoder",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
assert "lang_bpe" in str(params.lang_dir), "Currently only supports bpe model."
|
||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
device=device,
|
||||
sos_token="<sos/eos>",
|
||||
eos_token="<sos/eos>",
|
||||
)
|
||||
# sos_id, eos_id, ignore_id will be used in AttentionDecoderModel
|
||||
params.sos_id = graph_compiler.sos_id
|
||||
params.eos_id = graph_compiler.eos_id
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.vocab_size = graph_compiler.sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
|
||||
if params.decoding_method == "ctc-decoding":
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||
else:
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
HLG.scores *= params.hlg_scale
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.decoding_method in (
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
):
|
||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
logging.warning("It may take 8 minutes.")
|
||||
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
# G.aux_labels is not needed in later computations, so
|
||||
# remove it here.
|
||||
del G.aux_labels
|
||||
# CAUTION: The following line is crucial.
|
||||
# Arcs entering the back-off state have label equal to #0.
|
||||
# We have to change it to 0 here.
|
||||
G.labels[G.labels >= first_word_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set G.properties to None
|
||||
G.__dict__["_properties"] = None
|
||||
G = k2.Fsa.from_fsas([G]).to(device)
|
||||
G = k2.arc_sort(G)
|
||||
# Save a dummy value so that it can be loaded in C++.
|
||||
# See https://github.com/pytorch/pytorch/issues/67902
|
||||
# for why we need to do this.
|
||||
G.dummy = 1
|
||||
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G = G.to(device)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
else:
|
||||
G = None
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_ctc_attention_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
num_param = sum([p.numel() for p in model.decoder.parameters()])
|
||||
logging.info(f"Number of parameters in attention decoder: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
word_table=lexicon.word_table,
|
||||
G=G,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/librispeech/ASR/zipformer_ctc_attn/encoder_interface.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc_attn/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/encoder_interface.py
|
1
egs/librispeech/ASR/zipformer_ctc_attn/label_smoothing.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc_attn/label_smoothing.py
Symbolic link
@ -0,0 +1 @@
|
||||
../conformer_ctc/label_smoothing.py
|
95
egs/librispeech/ASR/zipformer_ctc_attn/model.py
Normal file
95
egs/librispeech/ASR/zipformer_ctc_attn/model.py
Normal file
@ -0,0 +1,95 @@
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
|
||||
class CTCAttentionModel(nn.Module):
|
||||
"""Hybrid CTC & Attention decoder model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
encoder_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
It is the Zipformer encoder model. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the attention decoder.
|
||||
encoder_dim:
|
||||
The embedding dimension of encoder.
|
||||
vocab_size:
|
||||
The vocabulary size.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
|
||||
self.encoder = encoder
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Linear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
# Attention decoder
|
||||
self.decoder = decoder
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
token_ids: List[List[int]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
token_ids:
|
||||
A list of token id list.
|
||||
|
||||
Returns:
|
||||
- ctc_output, ctc log-probs
|
||||
- att_loss, attention decoder loss
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert x.size(0) == x_lens.size(0) == len(token_ids)
|
||||
|
||||
# encoder forward
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# compute ctc log-probs
|
||||
ctc_output = self.ctc_output(encoder_out)
|
||||
|
||||
# compute attention decoder loss
|
||||
att_loss = self.decoder.calc_att_loss(encoder_out, x_lens, token_ids)
|
||||
|
||||
return ctc_output, att_loss
|
1
egs/librispeech/ASR/zipformer_ctc_attn/optim.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc_attn/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/optim.py
|
1
egs/librispeech/ASR/zipformer_ctc_attn/scaling.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc_attn/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/scaling.py
|
1
egs/librispeech/ASR/zipformer_ctc_attn/scaling_converter.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc_attn/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/scaling_converter.py
|
1269
egs/librispeech/ASR/zipformer_ctc_attn/train.py
Executable file
1269
egs/librispeech/ASR/zipformer_ctc_attn/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/ASR/zipformer_ctc_attn/zipformer.py
Symbolic link
1
egs/librispeech/ASR/zipformer_ctc_attn/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/zipformer.py
|
@ -1083,6 +1083,143 @@ def rescore_with_attention_decoder(
|
||||
return ans
|
||||
|
||||
|
||||
def rescore_with_attention_decoder2(
|
||||
lattice: k2.Fsa,
|
||||
num_paths: int,
|
||||
attention_decoder: torch.nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
nbest_scale: float = 1.0,
|
||||
ngram_lm_scale: Optional[float] = None,
|
||||
attention_scale: Optional[float] = None,
|
||||
use_double_scores: bool = True,
|
||||
) -> Dict[str, k2.Fsa]:
|
||||
"""This function extracts `num_paths` paths from the given lattice and uses
|
||||
an attention decoder to rescore them. The path with the highest score is
|
||||
the decoding output.
|
||||
|
||||
Args:
|
||||
lattice:
|
||||
An FsaVec with axes [utt][state][arc].
|
||||
num_paths:
|
||||
Number of paths to extract from the given lattice for rescoring.
|
||||
attention_decoder:
|
||||
A transformer model. See the class "Transformer" in
|
||||
conformer_ctc/transformer.py for its interface.
|
||||
memory:
|
||||
The encoder memory of the given model. It is the output of
|
||||
the last torch.nn.TransformerEncoder layer in the given model.
|
||||
Its shape is `(T, N, C)`.
|
||||
nbest_scale:
|
||||
It's the scale applied to `lattice.scores`. A smaller value
|
||||
leads to more unique paths at the risk of missing the correct path.
|
||||
ngram_lm_scale:
|
||||
Optional. It specifies the scale for n-gram LM scores.
|
||||
attention_scale:
|
||||
Optional. It specifies the scale for attention decoder scores.
|
||||
Returns:
|
||||
A dict of FsaVec, whose key contains a string
|
||||
ngram_lm_scale_attention_scale and the value is the
|
||||
best decoding path for each utterance in the lattice.
|
||||
"""
|
||||
max_loop_count = 10
|
||||
loop_count = 0
|
||||
while loop_count <= max_loop_count:
|
||||
try:
|
||||
nbest = Nbest.from_lattice(
|
||||
lattice=lattice,
|
||||
num_paths=num_paths,
|
||||
use_double_scores=use_double_scores,
|
||||
nbest_scale=nbest_scale,
|
||||
)
|
||||
# nbest.fsa.scores are all 0s at this point
|
||||
nbest = nbest.intersect(lattice)
|
||||
break
|
||||
except RuntimeError as e:
|
||||
logging.info(f"Caught exception:\n{e}\n")
|
||||
logging.info(f"num_paths before decreasing: {num_paths}")
|
||||
num_paths = int(num_paths / 2)
|
||||
if loop_count >= max_loop_count or num_paths <= 0:
|
||||
logging.info("Return None as the resulting lattice is too large.")
|
||||
return None
|
||||
logging.info(
|
||||
"This OOM is not an error. You can ignore it. "
|
||||
"If your model does not converge well, or --max-duration "
|
||||
"is too large, or the input sound file is difficult to "
|
||||
"decode, you will meet this exception."
|
||||
)
|
||||
logging.info(f"num_paths after decreasing: {num_paths}")
|
||||
loop_count += 1
|
||||
|
||||
# Now nbest.fsa has its scores set.
|
||||
# Also, nbest.fsa inherits the attributes from `lattice`.
|
||||
assert hasattr(nbest.fsa, "lm_scores")
|
||||
|
||||
am_scores = nbest.compute_am_scores()
|
||||
ngram_lm_scores = nbest.compute_lm_scores()
|
||||
|
||||
# The `tokens` attribute is set inside `compile_hlg.py`
|
||||
assert hasattr(nbest.fsa, "tokens")
|
||||
assert isinstance(nbest.fsa.tokens, torch.Tensor)
|
||||
|
||||
path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
|
||||
# the shape of memory is (T, N, C), so we use axis=1 here
|
||||
expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
|
||||
expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)
|
||||
|
||||
# remove axis corresponding to states.
|
||||
tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
|
||||
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
|
||||
tokens = tokens.remove_values_leq(0)
|
||||
token_ids = tokens.tolist()
|
||||
|
||||
if len(token_ids) == 0:
|
||||
print("Warning: rescore_with_attention_decoder(): empty token-ids")
|
||||
return None
|
||||
|
||||
nll = attention_decoder.nll(
|
||||
encoder_out=expanded_encoder_out,
|
||||
encoder_out_lens=expanded_encoder_out_lens,
|
||||
token_ids=token_ids,
|
||||
)
|
||||
assert nll.ndim == 2
|
||||
assert nll.shape[0] == len(token_ids)
|
||||
|
||||
attention_scores = -nll.sum(dim=1)
|
||||
|
||||
if ngram_lm_scale is None:
|
||||
ngram_lm_scale_list = [0.01, 0.05, 0.08]
|
||||
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
|
||||
else:
|
||||
ngram_lm_scale_list = [ngram_lm_scale]
|
||||
|
||||
if attention_scale is None:
|
||||
attention_scale_list = [0.01, 0.05, 0.08]
|
||||
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
|
||||
else:
|
||||
attention_scale_list = [attention_scale]
|
||||
|
||||
ans = dict()
|
||||
for n_scale in ngram_lm_scale_list:
|
||||
for a_scale in attention_scale_list:
|
||||
tot_scores = (
|
||||
am_scores.values
|
||||
+ n_scale * ngram_lm_scores.values
|
||||
+ a_scale * attention_scores
|
||||
)
|
||||
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(nbest.fsa, max_indexes)
|
||||
|
||||
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
|
||||
ans[key] = best_path
|
||||
return ans
|
||||
|
||||
|
||||
def rescore_with_rnn_lm(
|
||||
lattice: k2.Fsa,
|
||||
num_paths: int,
|
||||
|
@ -1171,10 +1171,13 @@ class MetricsTracker(collections.defaultdict):
|
||||
|
||||
def __str__(self) -> str:
|
||||
ans_frames = ""
|
||||
ans_symbols = ""
|
||||
ans_utterances = ""
|
||||
for k, v in self.norm_items():
|
||||
norm_value = "%.4g" % v
|
||||
if "utt_" not in k:
|
||||
if k == "att_loss":
|
||||
ans_symbols += str(k) + "=" + str(norm_value) + ", "
|
||||
elif "utt_" not in k:
|
||||
ans_frames += str(k) + "=" + str(norm_value) + ", "
|
||||
else:
|
||||
ans_utterances += str(k) + "=" + str(norm_value)
|
||||
@ -1186,11 +1189,13 @@ class MetricsTracker(collections.defaultdict):
|
||||
raise ValueError(f"Unexpected key: {k}")
|
||||
frames = "%.2f" % self["frames"]
|
||||
ans_frames += "over " + str(frames) + " frames. "
|
||||
symbols = "%.2f" % self["symbols"]
|
||||
ans_symbols += "over " + str(symbols) + " symbols. "
|
||||
if ans_utterances != "":
|
||||
utterances = "%.2f" % self["utterances"]
|
||||
ans_utterances += "over " + str(utterances) + " utterances."
|
||||
|
||||
return ans_frames + ans_utterances
|
||||
return ans_frames + ans_symbols + ans_utterances
|
||||
|
||||
def norm_items(self) -> List[Tuple[str, float]]:
|
||||
"""
|
||||
@ -1198,14 +1203,18 @@ class MetricsTracker(collections.defaultdict):
|
||||
[('ctc_loss', 0.1), ('att_loss', 0.07)]
|
||||
"""
|
||||
num_frames = self["frames"] if "frames" in self else 1
|
||||
num_symbols = self["symbols"] if "symbols" in self else 1
|
||||
num_utterances = self["utterances"] if "utterances" in self else 1
|
||||
ans = []
|
||||
for k, v in self.items():
|
||||
if k == "frames" or k == "utterances":
|
||||
if k == "frames" or k == "symbols" or k == "utterances":
|
||||
continue
|
||||
norm_value = (
|
||||
float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
|
||||
)
|
||||
if k == "att_loss":
|
||||
norm_value = float(v) / num_symbols
|
||||
elif "utt_" in k:
|
||||
norm_value = float(v) / num_utterances
|
||||
else:
|
||||
norm_value = float(v) / num_frames
|
||||
ans.append((k, norm_value))
|
||||
return ans
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user