Merge 3364d9863c7816da59b9df30f58784617ab23b84 into ff2bef9e501a4b5ebfec04cbfe8afa2e8bea4b40

This commit is contained in:
Zengwei Yao 2024-06-20 03:39:40 +00:00 committed by GitHub
commit ad4615eba3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 3110 additions and 6 deletions

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/asr_datamodule.py

View File

@ -0,0 +1,746 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The model structure is modified from Daniel Povey's Zipformer
# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
import itertools
import logging
import math
import random
from typing import List, Tuple
import k2
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
Identity,
MaxEig,
ScaledConv1d,
ScaledLinear,
Whiten,
_diag,
penalize_abs_values_gt,
random_clamp,
softmax,
)
from zipformer import FeedforwardModule
from icefall.utils import add_eos, add_sos, make_pad_mask
class AttentionDecoderModel(nn.Module):
    """Attention-decoder branch: a TransformerDecoder plus its training loss.

    Args:
      vocab_size (int): Number of classes.
      decoder_dim (int): Passed to TransformerDecoder as its model dimension.
      unmasked_dim (int): Passed through to TransformerDecoder.
      num_decoder_layers (int): Number of decoder layers.
      attention_dim (int): Passed through to TransformerDecoder.
      nhead (int): Number of attention heads.
      feedforward_dim (int): Passed through to TransformerDecoder.
      sos_id (int): ID prepended to every token sequence fed to the decoder.
      eos_id (int): ID appended to every target sequence; it is also the
        padding value of the decoder input.
      dropout (float): Dropout rate passed to TransformerDecoder.
      ignore_id (int): Padding value of the targets; passed as `ignore_index`
        to both loss computations.
      warmup_batches (float): The decoder is given the warm-up window
        [0.5 * warmup_batches, 1.0 * warmup_batches].
      label_smoothing (float): Passed to LabelSmoothingLoss.
    """

    def __init__(
        self,
        vocab_size: int,
        decoder_dim: int,
        unmasked_dim: int,
        num_decoder_layers: int,
        attention_dim: int,
        nhead: int,
        feedforward_dim: int,
        sos_id: int,
        eos_id: int,
        dropout: float = 0.1,
        ignore_id: int = -1,
        warmup_batches: float = 4000.0,
        label_smoothing: float = 0.1,
    ):
        super().__init__()
        self.eos_id = eos_id
        self.sos_id = sos_id
        self.ignore_id = ignore_id

        # The decoder layers only start warming up half-way through the
        # warm-up period and finish at its end.
        self.decoder = TransformerDecoder(
            vocab_size,
            decoder_dim,
            unmasked_dim,
            num_decoder_layers,
            attention_dim,
            nhead,
            feedforward_dim,
            dropout,
            warmup_begin=warmup_batches * 0.5,
            warmup_end=warmup_batches * 1.0,
        )

        # Loss used by calc_att_loss(); summed over all target positions.
        self.loss_fun = LabelSmoothingLoss(
            ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum"
        )

    def _pre_ys_in_out(self, token_ids: List[List[int]], device: torch.device):
        """Build the padded decoder input and target from raw token IDs.

        Returns:
          A tuple (ys_in_pad, ys_in_lens, ys_out_pad):
            - ys_in_pad: int64, SOS + tokens, padded with eos_id.
            - ys_in_lens: number of tokens per utterance plus one.
            - ys_out_pad: int64, tokens + EOS, padded with ignore_id.
        """
        ragged = k2.RaggedTensor(token_ids).to(device=device)
        splits = ragged.shape.row_splits(1)
        num_tokens = splits[1:] - splits[:-1]

        # [B, S+1], every row starts with SOS.
        with_sos = add_sos(ragged, sos_id=self.sos_id)
        ys_in_pad = with_sos.pad(mode="constant", padding_value=self.eos_id)

        # [B, S+1], every row ends with EOS.
        with_eos = add_eos(ragged, eos_id=self.eos_id)
        ys_out_pad = with_eos.pad(mode="constant", padding_value=self.ignore_id)

        return ys_in_pad.to(torch.int64), num_tokens + 1, ys_out_pad.to(torch.int64)

    def _run_decoder(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        token_ids: List[List[int]],
    ):
        """Run the decoder on `token_ids`; return (decoder output, padded targets)."""
        ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(
            token_ids, encoder_out.device
        )
        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
        return decoder_out, ys_out_pad

    def calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        token_ids: List[List[int]],
    ) -> torch.Tensor:
        """Calculate attention-decoder loss.

        Args:
          encoder_out: (batch, num_frames, encoder_dim)
          encoder_out_lens: (batch,)
          token_ids: A list of token id list.

        Return: The attention-decoder loss, as computed by `self.loss_fun`.
        """
        scores, targets = self._run_decoder(encoder_out, encoder_out_lens, token_ids)
        return self.loss_fun(x=scores, target=targets)

    def nll(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        token_ids: List[List[int]],
    ) -> torch.Tensor:
        """Compute negative log likelihood (nll) from attention-decoder.

        Args:
          encoder_out: (batch, num_frames, encoder_dim)
          encoder_out_lens: (batch,)
          token_ids: A list of token id list.

        Return: A tensor of shape (batch, num_tokens); no reduction is applied
          and positions equal to ignore_id are passed as `ignore_index`.
        """
        scores, targets = self._run_decoder(encoder_out, encoder_out_lens, token_ids)
        num_utts, _, vocab = scores.size()
        per_token = nn.functional.cross_entropy(
            scores.view(-1, vocab),
            targets.view(-1),
            ignore_index=self.ignore_id,
            reduction="none",
        )
        return per_token.view(num_utts, -1)
class TransformerDecoder(nn.Module):
    """Transformer decoder module.

    It is modified from https://github.com/espnet/espnet/blob/master/espnet2/asr/decoder/transformer_decoder.py.

    Args:
      vocab_size: output dim
      d_model: decoder dimension
      unmasked_dim: feature dims with index >= unmasked_dim may be zeroed on
        randomly chosen frames during training (see get_feature_mask())
      num_decoder_layers: number of decoder layers
      attention_dim: total dimension of multi head attention
      nhead: number of attention heads
      feedforward_dim: hidden dimension of feed_forward module
      dropout: dropout rate passed to every DecoderLayer
      warmup_begin: batch count at which layer warm-up starts
      warmup_end: batch count at which layer warm-up ends; the interval
        [warmup_begin, warmup_end] is split evenly across the layers
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        unmasked_dim: int,
        num_decoder_layers: int,
        attention_dim: int,
        nhead: int,
        feedforward_dim: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
    ):
        super().__init__()
        self.unmasked_dim = unmasked_dim

        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

        # Using absolute positional encoding.
        # NOTE: its dropout rate is fixed at 0.1, independent of `dropout`.
        self.pos = PositionalEncoding(d_model, dropout_rate=0.1)

        self.num_layers = num_decoder_layers
        self.layers = nn.ModuleList(
            [
                DecoderLayer(d_model, attention_dim, nhead, feedforward_dim, dropout)
                for _ in range(num_decoder_layers)
            ]
        )

        self.output_layer = nn.Linear(d_model, vocab_size)

        # will be written to, see set_batch_count() Note: in inference time this
        # may be zero but should be treated as large, we can check if
        # self.training is true.
        self.batch_count = 0
        assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end)
        self.warmup_begin = warmup_begin
        self.warmup_end = warmup_end
        # module_seed is for when we need a random number that is unique to the module but
        # shared across jobs. It's used to randomly select how many layers to drop,
        # so that we can keep this consistent across worker tasks (for efficiency).
        self.module_seed = torch.randint(0, 1000, ()).item()

        # Give each layer its own consecutive slice of the warm-up interval.
        delta = (1.0 / num_decoder_layers) * (warmup_end - warmup_begin)
        cur_begin = warmup_begin
        for i in range(num_decoder_layers):
            self.layers[i].warmup_begin = cur_begin
            cur_begin += delta
            self.layers[i].warmup_end = cur_begin

    @staticmethod
    def _layerdrop_prob(
        batch_count: float, layer_warmup_begin: float, layer_warmup_end: float
    ) -> float:
        """Return the probability of dropping one layer at `batch_count`.

        The probability is 0.5 before the layer's warm-up window, 0.05 after
        it, and is linearly interpolated from 0.5 to 0.05 inside the window.
        """
        initial_layerdrop_prob = 0.5
        final_layerdrop_prob = 0.05

        if batch_count == 0:
            # As a special case, if batch_count == 0, return 0 (drop no
            # layers).  This is rather ugly, I'm afraid; it is intended to
            # enable our scan_pessimistic_batches_for_oom() code to work correctly
            # so if we are going to get OOM it will happen early.
            return 0.0
        if batch_count < layer_warmup_begin:
            return initial_layerdrop_prob
        if batch_count > layer_warmup_end:
            return final_layerdrop_prob

        width = layer_warmup_end - layer_warmup_begin
        if width <= 0:
            # Zero-width window: warm-up of this layer is already complete.
            return final_layerdrop_prob
        # Linearly interpolate over the window.  The denominator is the width
        # of the window (not its end point), so that t runs from 0 to 1.
        t = (batch_count - layer_warmup_begin) / width
        assert 0.0 <= t < 1.001, t
        return initial_layerdrop_prob + t * (
            final_layerdrop_prob - initial_layerdrop_prob
        )

    def get_layers_to_drop(self, rnd_seed: int):
        """Return the set of layer indexes to skip for the current batch.

        Always empty in eval mode.  In training mode the *number* of dropped
        layers is drawn from an RNG seeded with batch_count + module_seed,
        while *which* layers are dropped is drawn from an RNG seeded with
        `rnd_seed`.
        """
        ans = set()
        if not self.training:
            return ans

        batch_count = self.batch_count
        num_layers = len(self.layers)

        shared_rng = random.Random(batch_count + self.module_seed)
        independent_rng = random.Random(rnd_seed)

        layerdrop_probs = [
            self._layerdrop_prob(batch_count, layer.warmup_begin, layer.warmup_end)
            for layer in self.layers
        ]
        tot = sum(layerdrop_probs)
        # Instead of drawing the samples independently, we first randomly decide
        # how many layers to drop out, using the same random number generator between
        # jobs so that all jobs drop out the same number (this is for speed).
        # Then we use an approximate approach to drop out the individual layers
        # with their specified probs while reaching this exact target.
        num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))

        layers = list(range(num_layers))
        independent_rng.shuffle(layers)

        # go through the shuffled layers until we get the required number of samples.
        if num_to_drop > 0:
            for layer in itertools.cycle(layers):
                if independent_rng.random() < layerdrop_probs[layer]:
                    ans.add(layer)
                if len(ans) == num_to_drop:
                    break
        if shared_rng.random() < 0.005 or __name__ == "__main__":
            logging.info(
                f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
                f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}"
            )
        return ans

    def get_feature_mask(self, x: torch.Tensor) -> float:
        # Note: the annotation is kept as `float` (the eval-mode value); in
        # training mode a Tensor is returned, see below.
        """Return a multiplicative mask for the decoder features.

        In eval mode (or when scripting/tracing) this returns 1.0.  In
        training mode it returns a tensor of shape (batch_size, num_frames,
        d_model) filled with ones, except that on about 15% of the frames,
        chosen at random, all dims with index >= self.unmasked_dim are zero.

        Args:
          x: the embeddings (needed for the shape, dtype and device), of shape
            (batch_size, num_frames, d_model)
        """
        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
            return 1.0

        batch_size, num_frames, d_model = x.size()

        feature_mask_dropout_prob = 0.15
        # 1.0 on frames that keep all dims, 0.0 on frames that get masked.
        frame_mask = (
            torch.rand(batch_size, num_frames, 1, device=x.device)
            > feature_mask_dropout_prob
        ).to(x.dtype)

        feature_mask = torch.ones(
            batch_size, num_frames, d_model, dtype=x.dtype, device=x.device
        )
        feature_mask[:, :, self.unmasked_dim :] *= frame_mask
        return feature_mask

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Forward decoder.

        Args:
          hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
          hlens: (batch)
          ys_in_pad: input token ids, int64 (batch, maxlen_out)
          ys_in_lens: (batch)

        Returns:
          Decoded token scores before softmax, of shape
          (batch, maxlen_out, vocab_size).
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L), True at padded positions
        tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
        # m: (1, L, L), True where attention is allowed (causal)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L), True where attention must be blocked
        tgt_mask = tgt_mask | (~m)

        memory = hs_pad
        memory_mask = make_pad_mask(hlens)[:, None, :].to(memory.device)

        tgt = self.embed(tgt)
        tgt = self.pos(tgt)

        rnd_seed = tgt.numel() + random.randint(0, 1000)
        layers_to_drop = self.get_layers_to_drop(rnd_seed)

        feature_mask = self.get_feature_mask(tgt)

        for i, mod in enumerate(self.layers):
            if i in layers_to_drop:
                continue
            tgt = mod(tgt, tgt_mask, memory, memory_mask)
            tgt = tgt * feature_mask

        tgt = self.output_layer(tgt)
        return tgt
class DecoderLayer(nn.Module):
    """Single decoder layer module.

    Args:
      d_model: equal to encoder_dim
      attention_dim: total dimension of multi head attention
      nhead: number of attention heads
      feedforward_dim: hidden dimension of feed_forward module
      dropout: dropout rate of the feed-forward module
    """

    def __init__(
        self,
        d_model: int,
        attention_dim: int,
        nhead: int,
        feedforward_dim: int = 2048,
        dropout: float = 0.1,
    ):
        """Construct an DecoderLayer object."""
        super(DecoderLayer, self).__init__()

        # will be written to, see set_batch_count()
        self.batch_count = 0

        self.self_attn = MultiHeadedAttention(
            d_model, attention_dim, nhead, dropout=0.0
        )
        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
        self.feed_forward = FeedforwardModule(d_model, feedforward_dim, dropout)

        self.norm_final = BasicNorm(d_model)

        self.bypass_scale = nn.Parameter(torch.tensor(0.5))

        # try to ensure the output is close to zero-mean (or at least, zero-median).
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.whiten = Whiten(
            num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01
        )

    def get_bypass_scale(self):
        """Return the bypass scale, usually clamped to [clamp_min, 1.0] in training.

        In eval mode (or when scripting/tracing) the raw parameter is
        returned.  In training mode the lower clamp decreases linearly from
        0.75 to 0.25 over the first 20000 batches.
        """
        if not self.training or torch.jit.is_scripting() or torch.jit.is_tracing():
            return self.bypass_scale
        if random.random() < 0.1:
            # ensure we get grads if self.bypass_scale becomes out of range
            return self.bypass_scale
        # hardcode warmup period for bypass scale
        warmup_period = 20000.0
        initial_clamp_min = 0.75
        final_clamp_min = 0.25
        if self.batch_count > warmup_period:
            clamp_min = final_clamp_min
        else:
            clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * (
                initial_clamp_min - final_clamp_min
            )
        return self.bypass_scale.clamp(min=clamp_min, max=1.0)

    def get_dynamic_dropout_rate(self):
        """Return the skip probability of the attention sub-modules.

        This starts at 0.2 and decreases linearly to 0 over the first 2000
        batches.  Its purpose is to keep the training stable at the beginning,
        by making the network focus on the feedforward modules.  It is 0.0 in
        eval mode and when scripting/tracing.
        """
        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
            return 0.0
        warmup_period = 2000.0
        initial_dropout_rate = 0.2
        final_dropout_rate = 0.0
        if self.batch_count > warmup_period:
            return final_dropout_rate
        # Interpolate between the two rates.  The step must be the
        # *difference* of the rates; using their product (which is zero here)
        # would keep the rate stuck at its initial value during warm-up.
        return initial_dropout_rate - (initial_dropout_rate - final_dropout_rate) * (
            self.batch_count / warmup_period
        )

    def forward(self, tgt, tgt_mask, memory, memory_mask):
        """Compute decoded features.

        Args:
          tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
          tgt_mask (torch.Tensor): Self-attention mask, forwarded unchanged to
            self.self_attn; TransformerDecoder.forward() passes a
            (#batch, maxlen_out, maxlen_out) mask.
          memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
          memory_mask (torch.Tensor): Cross-attention mask, forwarded unchanged
            to self.src_attn; TransformerDecoder.forward() passes a
            (#batch, 1, maxlen_in) mask.

        Returns:
          torch.Tensor: Output tensor (#batch, maxlen_out, size).
        """
        tgt_orig = tgt

        # dropout rate for submodules that interact with time.
        dynamic_dropout = self.get_dynamic_dropout_rate()

        # self-attn module (skipped entirely with probability dynamic_dropout)
        if random.random() >= dynamic_dropout:
            tgt = tgt + self.self_attn(tgt, tgt, tgt, tgt_mask)

        # cross-attn module (skipped entirely with probability dynamic_dropout)
        if random.random() >= dynamic_dropout:
            tgt = tgt + self.src_attn(tgt, memory, memory, memory_mask)

        # feed-forward module
        tgt = tgt + self.feed_forward(tgt)

        tgt = self.norm_final(self.balancer(tgt))

        # Blend the layer's input with its output using the bypass scale.
        delta = tgt - tgt_orig
        tgt = tgt_orig + delta * self.get_bypass_scale()

        return self.whiten(tgt)
class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Queries and keys are projected to `attention_dim`; values are projected to
    `attention_dim // 2`, and the result is projected back to `embed_dim`.

    Args:
      embed_dim: total dimension of the model.
      attention_dim: dimension in the attention module, may be less or more than embed_dim
        but must be a multiple of num_heads.
      num_heads: parallel attention heads.
      dropout: dropout probability applied to the attention weights. Default: 0.0.
    """

    def __init__(
        self, embed_dim: int, attention_dim: int, num_heads: int, dropout: float = 0.0
    ):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        self.embed_dim = embed_dim
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = attention_dim // num_heads

        # The per-head value dim is head_dim // 2, so head_dim must be even.
        assert self.head_dim % 2 == 0, self.head_dim
        assert self.head_dim * num_heads == attention_dim, (
            self.head_dim,
            num_heads,
            attention_dim,
        )

        # the initial_scale is supposed to take over the "scaling" factor of
        # head_dim ** -0.5, dividing it between the query and key.
        proj_scale = self.head_dim**-0.25
        value_dim = attention_dim // 2
        self.linear_q = ScaledLinear(
            embed_dim, attention_dim, bias=True, initial_scale=proj_scale
        )
        self.linear_k = ScaledLinear(
            embed_dim, attention_dim, bias=True, initial_scale=proj_scale
        )
        self.linear_v = ScaledLinear(
            embed_dim, value_dim, bias=True, initial_scale=proj_scale
        )

        # whiten_v / whiten_k are applied to the values / keys in forward().
        # They are meant to leave the forward values unchanged and to
        # discourage low-rank distributions by modifying the gradients.
        self.whiten_v = Whiten(
            num_groups=num_heads,
            whitening_limit=2.0,
            prob=(0.025, 0.25),
            grad_scale=0.025,
        )
        self.whiten_k = Whiten(
            num_groups=num_heads,
            whitening_limit=2.0,
            prob=(0.025, 0.25),
            grad_scale=0.025,
        )
        self.out_proj = ScaledLinear(
            value_dim, embed_dim, bias=True, initial_scale=0.05
        )

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
          query (torch.Tensor): Query tensor (#batch, time1, size).
          key (torch.Tensor): Key tensor (#batch, time2, size).
          value (torch.Tensor): Value tensor (#batch, time2, size).
          mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
            (#batch, time1, time2); True marks positions that must not be
            attended to.  May be None.

        Returns:
          torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        batch, q_len, _ = query.size()
        kv_len = key.size(1)
        n_heads = self.num_heads
        d_head = self.head_dim
        d_value = d_head // 2

        q = self.linear_q(query)
        k = self.linear_k(key)
        v = self.linear_v(value)

        k = self.whiten_k(k)  # does nothing in the forward pass.
        v = self.whiten_v(v)  # does nothing in the forward pass.

        # (batch, head, time1, head_dim)
        q = q.reshape(batch, q_len, n_heads, d_head).transpose(1, 2)
        # (batch, head, head_dim, time2)
        k = k.reshape(batch, kv_len, n_heads, d_head).permute(0, 2, 3, 1)
        # (batch * head, time2, head_dim // 2)
        v = v.reshape(batch, kv_len, n_heads, d_value)
        v = v.transpose(1, 2).reshape(batch * n_heads, kv_len, d_value)

        # (batch, head, time1, time2)
        scores = torch.matmul(q, k)

        # This is a harder way of limiting the attention scores to not be too large.
        # It incurs a penalty if any of them has an absolute value greater than
        # `limit` (25.0 here); this should be outside the normal range of the
        # attention scores.  We use this mechanism instead of, say, a limit on
        # entropy, because once the entropy gets very small gradients through
        # the softmax can become very small, and some mechanisms like that
        # become ineffective.
        scores = penalize_abs_values_gt(scores, limit=25.0, penalty=1.0e-04)

        if mask is not None:
            # Broadcast the mask over the head dimension.
            scores = scores.masked_fill(mask.unsqueeze(1), float("-inf"))

        scores = scores.view(batch * n_heads, q_len, kv_len)

        # Using this version of softmax, defined in scaling.py,
        # should save a little of the memory used in backprop by, if
        # we are in automatic mixed precision mode (amp) == autocast,
        # only storing the half-precision output for backprop purposes.
        weights = softmax(scores, dim=-1)
        weights = nn.functional.dropout(
            weights, p=self.dropout, training=self.training
        )

        # (batch * head, time1, head_dim // 2)
        context = torch.bmm(weights, v)
        assert context.shape == (batch * n_heads, q_len, d_value)

        # Merge the heads back: (batch, time1, attention_dim // 2)
        context = context.reshape(batch, n_heads, q_len, d_value)
        context = context.transpose(1, 2)
        context = context.reshape(batch, q_len, self.attention_dim // 2)

        return self.out_proj(context)
class PositionalEncoding(nn.Module):
    """Sinusoidal absolute positional encoding.

    Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35.

    Args:
      d_model (int): Embedding dimension.
      dropout_rate (float): Dropout rate.
      max_len (int): Number of positions pre-computed at construction time;
        the table grows on demand for longer inputs.
    """

    def __init__(self, d_model, dropout_rate, max_len=5000):
        """Construct an PositionalEncoding object."""
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Plain attribute (not a registered buffer); extend_pe() keeps its
        # dtype and device in sync with the input.
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """Make sure self.pe covers x.size(1) positions and matches x's dtype/device."""
        num_positions = x.size(1)
        cached = self.pe
        if cached is not None and cached.size(1) >= num_positions:
            # The table is long enough; only move/cast it if necessary.
            if cached.dtype != x.dtype or cached.device != x.device:
                self.pe = cached.to(dtype=x.dtype, device=x.device)
            return

        # (Re)build the table: even columns hold sin, odd columns hold cos.
        positions = torch.arange(0, num_positions, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = positions * inv_freq
        table = torch.zeros(num_positions, self.d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Add positional encoding.

        The input is scaled by sqrt(d_model) before the encoding is added, and
        dropout is applied to the sum.

        Args:
          x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
          torch.Tensor: Encoded tensor (batch, time, `*`).
        """
        self.extend_pe(x)
        encoded = x * self.xscale + self.pe[:, : x.size(1)]
        return self.dropout(encoded)
def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """Create a lower-triangular mask of shape (size, size) for subsequent steps.

    Entry (i, j) is non-zero (True for the default bool dtype) iff j <= i.

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor

    >>> subsequent_mask(3)
    [[1, 0, 0],
     [1, 1, 0],
     [1, 1, 1]]
    """
    all_ones = torch.ones(size, size, device=device, dtype=dtype)
    # Keep the diagonal and everything below it.
    return torch.tril(all_ones)
def _test_attention_decoder_model():
    """Smoke test: build a small model in eval mode and print its loss and nll."""
    model = AttentionDecoderModel(
        vocab_size=500,
        decoder_dim=384,
        unmasked_dim=256,
        num_decoder_layers=6,
        attention_dim=192,
        nhead=8,
        feedforward_dim=2048,
        sos_id=1,
        eos_id=1,
        dropout=0.1,
        ignore_id=-1,
    )
    model.eval()

    # Two utterances of 50 frames each, with a feature dim equal to decoder_dim.
    num_utts, num_frames, feat_dim = 2, 50, 384
    encoder_out = torch.randn(num_utts, num_frames, feat_dim)
    encoder_out_lens = torch.full((num_utts,), num_frames)
    token_ids = [[1, 2, 3, 4], [2, 3, 10]]

    for compute in (model.calc_att_loss, model.nll):
        print(compute(encoder_out, encoder_out_lens, token_ids))
# Run the smoke test when this file is executed as a script.
if __name__ == "__main__":
    _test_attention_decoder_model()

View File

@ -0,0 +1,841 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Liyong Guo,
# Quandong Wang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) ctc-decoding
./zipformer_ctc_attn/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_ctc_attn/exp \
--max-duration 600 \
--decoding-method ctc-decoding
(2) 1best
./zipformer_ctc_attn/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_ctc_attn/exp \
--max-duration 600 \
--hlg-scale 0.8 \
--decoding-method 1best
(3) nbest
./zipformer_ctc_attn/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_ctc_attn/exp \
--max-duration 600 \
--hlg-scale 0.8 \
--decoding-method nbest
(4) nbest-rescoring
./zipformer_ctc_attn/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_ctc_attn/exp \
--max-duration 600 \
--hlg-scale 0.8 \
--lm-dir data/lm \
--decoding-method nbest-rescoring
(5) whole-lattice-rescoring
./zipformer_ctc_attn/ctc_decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./zipformer_ctc_attn/exp \
--max-duration 600 \
--hlg-scale 0.8 \
--lm-dir data/lm \
--decoding-method whole-lattice-rescoring
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_ctc_attention_model, get_params
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder2,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
# Natural logarithm of 1e-10.
# NOTE(review): not referenced in the part of this file visible here.
LOG_EPS = math.log(1e-10)
def get_parser():
    """Build the command-line parser of this decoding script.

    Besides the options defined here, the model options registered by
    `add_model_arguments` are added to the returned parser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=30,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=15,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="zipformer_ctc_attn/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--lang-dir",
        type=Path,
        default="data/lang_bpe_500",
        help="The lang dir containing word table and LG graph",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="ctc-decoding",
        help="""Decoding method.
        Supported values are:
        - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
          It needs neither a lexicon nor an n-gram LM.
        - (2) 1best. Extract the best path from the decoding lattice as the
          decoding result.
        - (3) nbest. Extract n paths from the decoding lattice; the path
          with the highest score is the decoding result.
        - (4) nbest-rescoring. Extract n paths from the decoding lattice,
          rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
          the highest score is the decoding result.
        - (5) whole-lattice-rescoring. Rescore the decoding lattice with an
          n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
          is the decoding result.
        - (6) nbest-oracle. Its WER is the lower bound of any n-best
          rescoring method can achieve. Useful for debugging n-best
          rescoring method.
        - (7) attention-decoder. Rescore the decoding lattice with an
          n-gram LM (e.g., a 4-gram LM), extract n paths from the rescored
          lattice and rescore them with the attention decoder.
        """,
    )

    parser.add_argument(
        "--num-paths",
        type=int,
        default=100,
        help="""Number of paths for n-best based decoding method.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, nbest-oracle, and attention-decoder
        """,
    )

    parser.add_argument(
        "--nbest-scale",
        type=float,
        default=0.5,
        help="""The scale to be applied to `lattice.scores`.
        It's needed if you use any kinds of n-best based rescoring.
        Used only when "method" is one of the following values:
        nbest, nbest-rescoring, nbest-oracle, and attention-decoder
        A smaller value results in more unique paths.
        """,
    )

    parser.add_argument(
        "--hlg-scale",
        type=float,
        default=0.8,
        help="""The scale to be applied to `hlg.scores`.
        """,
    )

    parser.add_argument(
        "--lm-dir",
        type=str,
        default="data/lm",
        help="""The n-gram LM dir.
        It should contain either G_4_gram.pt or G_4_gram.fst.txt
        """,
    )

    # Model-definition options shared with the training script.
    add_model_arguments(parser)

    return parser
def get_decoding_params() -> AttributeDict:
    """Return the fixed, non-command-line parameters used for decoding."""
    defaults = dict(
        frame_shift_ms=10,
        search_beam=20,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        use_double_scores=True,
    )
    return AttributeDict(defaults)
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    batch: dict,
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict.

    The returned dict maps a string describing the decoding setting to the
    decoding result:

        - key: `ctc-decoding` for CTC decoding, `no_rescore` for 1best,
               a string starting with `no_rescore-nbest-scale-` for nbest
               and with `oracle_` for nbest-oracle.  For the rescoring
               methods the keys are those returned by the rescoring function
               (e.g., `lm_scale_0.7`).
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.

    Args:
      params:
        It's the return value of :func:`get_params`.  `params.decoding_method`
        selects the method: "ctc-decoding", "nbest-oracle", "1best", "nbest",
        "nbest-rescoring", "whole-lattice-rescoring" or "attention-decoder".
      model:
        The neural model.  Its `encoder` and `ctc_output` are used to compute
        the network output; its `decoder` is used for "attention-decoder".
      HLG:
        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
      H:
        The ctc topo. Used only when params.decoding_method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.decoding_method is ctc-decoding.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      word_table:
        The word symbol table.
      G:
        An LM. It is not None when params.decoding_method is "nbest-rescoring",
        "whole-lattice-rescoring" or "attention-decoder". In general, the G
        in HLG is a 3-gram LM, while this G is a 4-gram LM.

    Returns:
      Return the decoding result. See above description for the format of
      the returned dict. Note: If it decodes to nothing, then return None.
    """
    device = HLG.device if HLG is not None else H.device

    feature = batch["inputs"]
    assert feature.ndim == 3
    # at entry, feature is (N, T, C)
    feature = feature.to(device)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
    # nnet_output is (N, T, C)
    nnet_output = model.ctc_output(encoder_out)

    # One row per utterance: (sequence index, start frame, number of frames),
    # with frame counts expressed after subsampling.
    segment_columns = (
        supervisions["sequence_idx"],
        supervisions["start_frame"] // params.subsampling_factor,
        supervisions["num_frames"] // params.subsampling_factor,
    )
    supervision_segments = torch.stack(segment_columns, 1).to(torch.int32)

    # Exactly one of H and HLG is expected to be given.
    if H is not None:
        assert HLG is None
        assert bpe_model is not None
        decoding_graph = H
    else:
        assert HLG is not None
        decoding_graph = HLG

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    def to_words(path):
        # Map the word IDs on `path` to word strings, one list per utterance.
        return [[word_table[i] for i in ids] for ids in get_texts(path)]

    method = params.decoding_method

    if method == "ctc-decoding":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        # Note: `best_path.aux_labels` contains token IDs, not word IDs
        # since we are using H, not HLG here.
        #
        # token_ids is a list-of-list of IDs
        token_ids = get_texts(best_path)

        # decoded is a list of str, e.g., ['xxx yyy zzz', ...]; each one is
        # split into a list of words, e.g., [['xxx', 'yyy', 'zzz'], ... ]
        decoded = bpe_model.decode(token_ids)
        return {"ctc-decoding": [s.split() for s in decoded]}

    if method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
        # as HLG decoding is faster and the oracle WER
        # is only slightly worse than that of rescored lattices.
        best_path = nbest_oracle(
            lattice=lattice,
            num_paths=params.num_paths,
            ref_texts=supervisions["text"],
            word_table=word_table,
            nbest_scale=params.nbest_scale,
            oov="<UNK>",
        )
        hyps = to_words(best_path)
        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
        return {key: hyps}

    if method == "1best":
        best_path = one_best_decoding(
            lattice=lattice, use_double_scores=params.use_double_scores
        )
        return {"no_rescore": to_words(best_path)}

    if method == "nbest":
        best_path = nbest_decoding(
            lattice=lattice,
            num_paths=params.num_paths,
            use_double_scores=params.use_double_scores,
            nbest_scale=params.nbest_scale,
        )
        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
        return {key: to_words(best_path)}

    assert params.decoding_method in [
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
    ]

    # LM scales to try: 0.1, 0.2, ..., 2.0
    lm_scale_list = [
        0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7,
        0.8, 0.9, 1.0, 1.1, 1.2, 1.3,
        1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0,
    ]  # fmt: skip

    if method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=params.nbest_scale,
        )
    elif method == "whole-lattice-rescoring":
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    elif method == "attention-decoder":
        # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM.
        rescored_lattice = rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=None,
        )
        best_path_dict = rescore_with_attention_decoder2(
            lattice=rescored_lattice,
            num_paths=params.num_paths,
            attention_decoder=model.decoder,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            nbest_scale=params.nbest_scale,
        )
    else:
        assert False, f"Unsupported decoding method: {params.decoding_method}"

    if best_path_dict is None:
        return None

    return {
        lm_scale_str: to_words(best_path)
        for lm_scale_str, best_path in best_path_dict.items()
    }
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: Optional[k2.Fsa],
    H: Optional[k2.Fsa],
    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode every batch of a dataloader and collect the hypotheses.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode. Each batch
        must provide ``batch["supervisions"]["text"]`` and
        ``batch["supervisions"]["cut"]`` (cuts are needed for their ids).
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      HLG:
        The decoding graph. Used only when params.decoding_method is NOT
        ctc-decoding.
      H:
        The ctc topo. Used only when params.decoding_method is ctc-decoding.
      bpe_model:
        The BPE model. Used only when params.decoding_method is ctc-decoding.
      word_table:
        It is the word symbol table.
      G:
        An LM. It is not None when params.decoding_method needs LM
        rescoring. In general, the G in HLG is a 3-gram LM, while this G
        is a 4-gram LM.
    Returns:
      Return a dict keyed by the decoding setting reported by
      :func:`decode_one_batch` (e.g. "no_rescore" when no LM rescoring is
      used). Its value is a list of tuples. Each tuple contains three
      elements: the cut id, the reference words and the hypothesis words.
      Utterances of a batch that could not be decoded are reported with an
      empty hypothesis under every key.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        # Some dataloaders (e.g. with iterable datasets) have no length.
        num_batches = "?"

    results = defaultdict(list)
    # Entries of batches for which decode_one_batch() returned None. They
    # are added to every setting once all settings are known, so that the
    # error statistics still count these utterances.
    undecoded = []
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            batch=batch,
            word_table=word_table,
            G=G,
        )

        if hyps_dict is None:
            # decode_one_batch() returns None when rescoring produced no
            # result for this batch; treat all its utterances as empty.
            logging.warning(
                f"batch {batch_idx}: decoding returned no result, "
                f"using empty hypotheses for {len(texts)} cuts"
            )
            undecoded.extend(
                (cut_id, ref_text.split(), [])
                for cut_id, ref_text in zip(cut_ids, texts)
            )
        else:
            for name, hyps in hyps_dict.items():
                assert len(hyps) == len(texts)
                results[name].extend(
                    (cut_id, ref_text.split(), hyp_words)
                    for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts)
                )

        num_cuts += len(texts)

        if batch_idx % 100 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

    if undecoded:
        if not results:
            logging.warning(
                "No batch was decoded successfully; the result dict is empty"
            )
        for name in results:
            results[name].extend(undecoded)

    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    Args:
      params:
        Must provide `res_dir` (output directory, used with the `/`
        operator) and `suffix` (appended to every file name).
      test_set_name:
        Name of the test set; it is part of every file name and log line.
      results_dict:
        Maps a decoding setting to its list of
        (cut_id, ref_words, hyp_words) tuples.

    For each setting, a ``recogs-*`` and an ``errs-*`` file are written.
    A single ``wer-summary-*`` file lists all settings sorted by WER.
    """
    if not results_dict:
        # Without any setting there is nothing to write; the summary file
        # name below also needs at least one key.
        logging.warning(f"No results to save for {test_set_name}")
        return

    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # NOTE: the summary file name has always embedded the LAST setting of
    # `results_dict`; this is kept so that existing file names do not change.
    last_key = key

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir
        / f"wer-summary-{test_set_name}-{last_key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for setting, val in test_set_wers:
            print("{}\t{}".format(setting, val), file=f)

    # Only the first (i.e. lowest-WER) line is tagged as the best one.
    lines = ["\nFor {}, WER of different settings are:\n".format(test_set_name)]
    note = "\tbest for {}".format(test_set_name)
    for setting, val in test_set_wers:
        lines.append("{}\t{}{}\n".format(setting, val, note))
        note = ""
    logging.info("".join(lines))
def _load_rescoring_lm(params, lexicon: Lexicon, device: torch.device) -> k2.Fsa:
    """Load the 4-gram LM `G` used for rescoring, compiling it if needed.

    Reads ``G_4_gram.pt`` from `params.lm_dir` when it exists; otherwise
    builds it from ``G_4_gram.fst.txt`` and caches the result as
    ``G_4_gram.pt``. The returned FSA has an `lm_scores` attribute.
    """
    compiled_path = params.lm_dir / "G_4_gram.pt"
    if not compiled_path.is_file():
        logging.info("Loading G_4_gram.fst.txt")
        logging.warning("It may take 8 minutes.")
        with open(params.lm_dir / "G_4_gram.fst.txt") as f:
            first_word_disambig_id = lexicon.word_table["#0"]

            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            # G.aux_labels is not needed in later computations, so
            # remove it here.
            del G.aux_labels
            # CAUTION: The following line is crucial.
            # Arcs entering the back-off state have label equal to #0.
            # We have to change it to 0 here.
            G.labels[G.labels >= first_word_disambig_id] = 0
            # See https://github.com/k2-fsa/k2/issues/874
            # for why we need to set G.properties to None
            G.__dict__["_properties"] = None
            G = k2.Fsa.from_fsas([G]).to(device)
            G = k2.arc_sort(G)
            # Save a dummy value so that it can be loaded in C++.
            # See https://github.com/pytorch/pytorch/issues/67902
            # for why we need to do this.
            G.dummy = 1

            torch.save(G.as_dict(), compiled_path)
    else:
        logging.info("Loading pre-compiled G_4_gram.pt")
        d = torch.load(compiled_path, map_location=device)
        G = k2.Fsa.from_dict(d)

    if params.decoding_method in ["whole-lattice-rescoring", "attention-decoder"]:
        # Add epsilon self-loops to G as we will compose
        # it with the whole lattice later
        G = k2.add_epsilon_self_loops(G)
        G = k2.arc_sort(G)
        G = G.to(device)

    # G.lm_scores is used to replace HLG.lm_scores during
    # LM rescoring.
    G.lm_scores = G.scores.clone()
    return G


def _find_iter_checkpoints(params, num_required):
    """Return the first `num_required` checkpoints found for --iter.

    Raises:
      ValueError: if fewer than `num_required` checkpoints are available.
    """
    filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
        :num_required
    ]
    if len(filenames) == 0:
        raise ValueError(
            f"No checkpoints found for"
            f" --iter {params.iter}, --avg {params.avg}"
        )
    elif len(filenames) < num_required:
        raise ValueError(
            f"Not enough checkpoints ({len(filenames)}) found for"
            f" --iter {params.iter}, --avg {params.avg}"
        )
    return filenames


def _load_model_weights(params, model: torch.nn.Module, device: torch.device) -> None:
    """Load (possibly averaged) checkpoint weights into `model` in place.

    The checkpoints are selected by `params.iter` (if > 0) or
    `params.epoch`, together with `params.avg` and
    `params.use_averaged_model`.
    """
    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = _find_iter_checkpoints(params, params.avg)
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            # Epochs are numbered from 1; silently skip non-existent ones.
            filenames = [
                f"{params.exp_dir}/epoch-{i}.pt"
                for i in range(start, params.epoch + 1)
                if i >= 1
            ]
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        return

    if params.iter > 0:
        # One extra checkpoint is needed: the start checkpoint is excluded
        # from the average.
        filenames = _find_iter_checkpoints(params, params.avg + 1)
        filename_start = filenames[-1]
        filename_end = filenames[0]
        logging.info(
            "Calculating the averaged model over iteration checkpoints"
            f" from {filename_start} (excluded) to {filename_end}"
        )
    else:
        assert params.avg > 0, params.avg
        start = params.epoch - params.avg
        assert start >= 1, start
        filename_start = f"{params.exp_dir}/epoch-{start}.pt"
        filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
        logging.info(
            f"Calculating the averaged model over epoch range from "
            f"{start} (excluded) to {params.epoch}"
        )

    model.to(device)
    model.load_state_dict(
        average_checkpoints_with_averaged_model(
            filename_start=filename_start,
            filename_end=filename_end,
            device=device,
        )
    )


@torch.no_grad()
def main():
    """Decode LibriSpeech test-clean and test-other and save the results.

    Parses the command line, builds the decoding graphs required by the
    selected `--decoding-method`, loads the model weights, then decodes
    both test sets and writes transcripts and WER statistics.
    """
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
    args.lang_dir = Path(args.lang_dir)
    args.lm_dir = Path(args.lm_dir)

    params = get_params()
    # add decoding params
    params.update(get_decoding_params())
    # Command-line values take precedence over the defaults above.
    params.update(vars(args))

    assert params.decoding_method in (
        "ctc-decoding",
        "1best",
        "nbest",
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "nbest-oracle",
        "attention-decoder",
    )
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"Device: {device}")

    assert "lang_bpe" in str(params.lang_dir), "Currently only supports bpe model."
    graph_compiler = BpeCtcTrainingGraphCompiler(
        params.lang_dir,
        device=device,
        sos_token="<sos/eos>",
        eos_token="<sos/eos>",
    )
    # sos_id, eos_id, ignore_id will be used in AttentionDecoderModel
    params.sos_id = graph_compiler.sos_id
    params.eos_id = graph_compiler.eos_id
    # <blk> is defined in local/train_bpe_model.py
    params.vocab_size = graph_compiler.sp.get_piece_size()

    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)

    if params.decoding_method == "ctc-decoding":
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
            modified=False,
            device=device,
        )
        bpe_model = spm.SentencePieceProcessor()
        bpe_model.load(str(params.lang_dir / "bpe.model"))
    else:
        H = None
        bpe_model = None
        HLG = k2.Fsa.from_dict(
            torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
        )
        assert HLG.requires_grad is False

        HLG.scores *= params.hlg_scale
        if not hasattr(HLG, "lm_scores"):
            HLG.lm_scores = HLG.scores.clone()

    if params.decoding_method in (
        "nbest-rescoring",
        "whole-lattice-rescoring",
        "attention-decoder",
    ):
        G = _load_rescoring_lm(params, lexicon, device)
    else:
        G = None

    logging.info("About to create model")
    model = get_ctc_attention_model(params)

    _load_model_weights(params, model, device)

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
    num_param = sum([p.numel() for p in model.decoder.parameters()])
    logging.info(f"Number of parameters in attention decoder: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    test_clean_cuts = librispeech.test_clean_cuts()
    test_other_cuts = librispeech.test_other_cuts()

    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
    test_other_dl = librispeech.test_dataloaders(test_other_cuts)

    test_sets = ["test-clean", "test-other"]
    # Named `test_dls` so that the loop variable does not shadow the list.
    test_dls = [test_clean_dl, test_other_dl]

    for test_set, test_dl in zip(test_sets, test_dls):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            HLG=HLG,
            H=H,
            bpe_model=bpe_model,
            word_table=lexicon.word_table,
            G=G,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/encoder_interface.py

View File

@ -0,0 +1 @@
../conformer_ctc/label_smoothing.py

View File

@ -0,0 +1,95 @@
# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
class CTCAttentionModel(nn.Module):
    """Joint CTC / attention-decoder model sharing a single encoder."""

    def __init__(
        self,
        encoder: EncoderInterface,
        decoder: nn.Module,
        encoder_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder:
            The encoder, an instance of `EncoderInterface`. It is called as
            `encoder(x, x_lens)` and returns the encoder output together
            with the output lengths.
          decoder:
            The attention decoder. It must provide
            `calc_att_loss(encoder_out, encoder_out_lens, token_ids)`.
          encoder_dim:
            The embedding dimension of the encoder output, i.e. the input
            dimension of the CTC projection.
          vocab_size:
            The vocabulary size, i.e. the output dimension of the CTC
            projection.
        """
        super().__init__()

        assert isinstance(encoder, EncoderInterface), type(encoder)
        self.encoder = encoder

        # CTC head: dropout -> linear projection -> log-softmax over vocab.
        ctc_layers = [
            nn.Dropout(p=0.1),
            nn.Linear(encoder_dim, vocab_size),
            nn.LogSoftmax(dim=-1),
        ]
        self.ctc_output = nn.Sequential(*ctc_layers)

        # Attention decoder
        self.decoder = decoder

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        token_ids: List[List[int]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder, then the CTC head and the attention decoder.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in
            `x` before padding.
          token_ids:
            A list of token id lists, one per utterance.
        Returns:
          A tuple of two items:
            - ctc_output, the CTC log-probs computed from the encoder output
            - att_loss, the attention decoder loss
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        batch_size = x.size(0)
        assert batch_size == x_lens.size(0) == len(token_ids)

        # encoder forward; the lengths now refer to the encoder output
        encoder_out, encoder_out_lens = self.encoder(x, x_lens)
        assert torch.all(encoder_out_lens > 0)

        # CTC branch
        ctc_log_probs = self.ctc_output(encoder_out)

        # attention-decoder branch
        att_loss = self.decoder.calc_att_loss(
            encoder_out, encoder_out_lens, token_ids
        )

        return ctc_log_probs, att_loss

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/optim.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/scaling.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/scaling_converter.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../pruned_transducer_stateless7/zipformer.py

View File

@ -1083,6 +1083,143 @@ def rescore_with_attention_decoder(
return ans
def rescore_with_attention_decoder2(
    lattice: k2.Fsa,
    num_paths: int,
    attention_decoder: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    nbest_scale: float = 1.0,
    ngram_lm_scale: Optional[float] = None,
    attention_scale: Optional[float] = None,
    use_double_scores: bool = True,
) -> Optional[Dict[str, k2.Fsa]]:
    """This function extracts `num_paths` paths from the given lattice and uses
    an attention decoder to rescore them. The path with the highest score is
    the decoding output.

    Args:
      lattice:
        An FsaVec with axes [utt][state][arc].
      num_paths:
        Number of paths to extract from the given lattice for rescoring.
        It is halved and the extraction retried whenever a RuntimeError
        (typically out-of-memory) is raised.
      attention_decoder:
        The attention decoder. It must provide
        `nll(encoder_out, encoder_out_lens, token_ids)` returning a 2-D
        tensor with one row per token sequence.
      encoder_out:
        The encoder output. It is indexed by utterance along dim 0
        (presumably of shape (N, T, C) -- confirm against the caller).
      encoder_out_lens:
        A 1-D tensor with the valid length of each utterance in
        `encoder_out`.
      nbest_scale:
        It's the scale applied to `lattice.scores`. A smaller value
        leads to more unique paths at the risk of missing the correct path.
      ngram_lm_scale:
        Optional. It specifies the scale for n-gram LM scores. If None, a
        default list of scales is tried.
      attention_scale:
        Optional. It specifies the scale for attention decoder scores. If
        None, a default list of scales is tried.
      use_double_scores:
        Passed to `Nbest.from_lattice`.
    Returns:
      A dict of FsaVec, whose key contains a string
      ngram_lm_scale_attention_scale and the value is the
      best decoding path for each utterance in the lattice.
      Return None if the n-best list could not be built (lattice too
      large) or if no token ids were extracted.
    """
    max_loop_count = 10
    loop_count = 0
    while loop_count <= max_loop_count:
        try:
            nbest = Nbest.from_lattice(
                lattice=lattice,
                num_paths=num_paths,
                use_double_scores=use_double_scores,
                nbest_scale=nbest_scale,
            )
            # nbest.fsa.scores are all 0s at this point
            nbest = nbest.intersect(lattice)
            break
        except RuntimeError as e:
            logging.info(f"Caught exception:\n{e}\n")
            logging.info(f"num_paths before decreasing: {num_paths}")
            num_paths = int(num_paths / 2)
            if loop_count >= max_loop_count or num_paths <= 0:
                logging.info("Return None as the resulting lattice is too large.")
                return None
            logging.info(
                "This OOM is not an error. You can ignore it. "
                "If your model does not converge well, or --max-duration "
                "is too large, or the input sound file is difficult to "
                "decode, you will meet this exception."
            )
            logging.info(f"num_paths after decreasing: {num_paths}")
        loop_count += 1

    # Now nbest.fsa has its scores set.
    # Also, nbest.fsa inherits the attributes from `lattice`.
    assert hasattr(nbest.fsa, "lm_scores")

    am_scores = nbest.compute_am_scores()
    ngram_lm_scores = nbest.compute_lm_scores()

    # The `tokens` attribute is set inside `compile_hlg.py`
    assert hasattr(nbest.fsa, "tokens")
    assert isinstance(nbest.fsa.tokens, torch.Tensor)

    path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
    # Replicate each utterance's encoder output once per extracted path;
    # utterances are indexed along dim 0.
    expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
    expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)

    # remove axis corresponding to states.
    tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
    tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
    tokens = tokens.remove_values_leq(0)

    token_ids = tokens.tolist()

    if not token_ids:
        logging.warning("rescore_with_attention_decoder2(): empty token-ids")
        return None

    nll = attention_decoder.nll(
        encoder_out=expanded_encoder_out,
        encoder_out_lens=expanded_encoder_out_lens,
        token_ids=token_ids,
    )
    assert nll.ndim == 2
    assert nll.shape[0] == len(token_ids)

    # Sum over dim 1 to get one score per path.
    attention_scores = -nll.sum(dim=1)

    # Scales searched when the caller does not fix a value.
    default_scale_list = [0.01, 0.05, 0.08]
    default_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
    default_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
    default_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]

    if ngram_lm_scale is None:
        ngram_lm_scale_list = default_scale_list
    else:
        ngram_lm_scale_list = [ngram_lm_scale]

    if attention_scale is None:
        attention_scale_list = default_scale_list
    else:
        attention_scale_list = [attention_scale]

    # Loop-invariant per-path scores.
    am_values = am_scores.values
    ngram_lm_values = ngram_lm_scores.values

    ans = dict()
    for n_scale in ngram_lm_scale_list:
        for a_scale in attention_scale_list:
            tot_scores = (
                am_values
                + n_scale * ngram_lm_values
                + a_scale * attention_scores
            )
            ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
            max_indexes = ragged_tot_scores.argmax()
            best_path = k2.index_fsa(nbest.fsa, max_indexes)

            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
            ans[key] = best_path
    return ans
def rescore_with_rnn_lm(
lattice: k2.Fsa,
num_paths: int,

View File

@ -1171,10 +1171,13 @@ class MetricsTracker(collections.defaultdict):
def __str__(self) -> str:
ans_frames = ""
ans_symbols = ""
ans_utterances = ""
for k, v in self.norm_items():
norm_value = "%.4g" % v
if "utt_" not in k:
if k == "att_loss":
ans_symbols += str(k) + "=" + str(norm_value) + ", "
elif "utt_" not in k:
ans_frames += str(k) + "=" + str(norm_value) + ", "
else:
ans_utterances += str(k) + "=" + str(norm_value)
@ -1186,11 +1189,13 @@ class MetricsTracker(collections.defaultdict):
raise ValueError(f"Unexpected key: {k}")
frames = "%.2f" % self["frames"]
ans_frames += "over " + str(frames) + " frames. "
symbols = "%.2f" % self["symbols"]
ans_symbols += "over " + str(symbols) + " symbols. "
if ans_utterances != "":
utterances = "%.2f" % self["utterances"]
ans_utterances += "over " + str(utterances) + " utterances."
return ans_frames + ans_utterances
return ans_frames + ans_symbols + ans_utterances
def norm_items(self) -> List[Tuple[str, float]]:
"""
@ -1198,14 +1203,18 @@ class MetricsTracker(collections.defaultdict):
[('ctc_loss', 0.1), ('att_loss', 0.07)]
"""
num_frames = self["frames"] if "frames" in self else 1
num_symbols = self["symbols"] if "symbols" in self else 1
num_utterances = self["utterances"] if "utterances" in self else 1
ans = []
for k, v in self.items():
if k == "frames" or k == "utterances":
if k == "frames" or k == "symbols" or k == "utterances":
continue
norm_value = (
float(v) / num_frames if "utt_" not in k else float(v) / num_utterances
)
if k == "att_loss":
norm_value = float(v) / num_symbols
elif "utt_" in k:
norm_value = float(v) / num_utterances
else:
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans