# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Tuple

import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from transformer import encoder_padding_mask

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.utils import encode_supervisions


class CTCModel(nn.Module):
    """It implements a CTC model with an auxiliary attention head."""

    def __init__(
        self,
        encoder: EncoderInterface,
        decoder: nn.Module,
        encoder_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder:
            An instance of `EncoderInterface`. The shared encoder for the
            CTC and attention branches.
          decoder:
            An instance of `nn.Module`. The decoder for the attention branch.
          encoder_dim:
            Dimension of the encoder output.
          vocab_size:
            Number of tokens of the modeling unit, including blank.
        """
        super().__init__()
        assert isinstance(encoder, EncoderInterface), type(encoder)

        self.encoder = encoder
        # The CTC head: maps each encoder frame to log-probabilities over
        # the vocabulary (including blank).
        self.ctc_output = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(encoder_dim, vocab_size),
            nn.LogSoftmax(dim=-1),
        )
        self.decoder = decoder
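
    # A minimal construction sketch (hypothetical; the recipe builds the
    # actual encoder/decoder in train.py, and the dimensions below are
    # placeholders):
    #
    #   encoder = Zipformer(...)      # any `EncoderInterface` implementation
    #   decoder = Transformer(...)    # attention decoder, or None to disable
    #   model = CTCModel(
    #       encoder=encoder, decoder=decoder, encoder_dim=384, vocab_size=500
    #   )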

    @torch.jit.ignore
    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        supervisions: Dict,
        graph_compiler: BpeCtcTrainingGraphCompiler,
        subsampling_factor: int = 1,
        beam_size: int = 10,
        reduction: str = "sum",
        use_double_scores: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """
        Args:
          x:
            Tensor of dimension (N, T, C), where N is the batch size,
            T is the number of frames, and C is the feature dimension.
          x_lens:
            Tensor of dimension (N,) containing the number of valid frames
            in `x` before padding.
          supervisions:
            Supervisions used in training; a dict with entries such as
            "text", "sequence_idx", "start_frame", and "num_frames".
          graph_compiler:
            Used to compile a decoding graph from texts.
          subsampling_factor:
            Used to compute the supervision segments on the encoder output.
          beam_size:
            Beam size used in `k2.ctc_loss`.
          reduction:
            Reduction method used in `k2.ctc_loss`.
          use_double_scores:
            If True, use double precision in `k2.ctc_loss`.
        Returns:
          Return the CTC loss, the attention loss, and the total number
          of frames.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        nnet_output, x_lens = self.encoder(x, x_lens)
        assert torch.all(x_lens > 0)

        # Compute CTC log-probs.
        ctc_output = self.ctc_output(nnet_output)

        # NOTE: We need `encode_supervisions` to sort sequences with
        # different durations in decreasing order, as required by
        # `k2.intersect_dense`, which is called inside `k2.ctc_loss`.
        supervision_segments, texts = encode_supervisions(
            supervisions, subsampling_factor=subsampling_factor
        )
        num_frames = supervision_segments[:, 2].sum().item()

        # Works with a BPE model.
        token_ids = graph_compiler.texts_to_ids(texts)
        decoding_graph = graph_compiler.compile(token_ids)
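
        # Each row of `supervision_segments` is
        # (sequence_idx, start_frame, num_frames), measured after subsampling.
        # `allow_truncate=subsampling_factor - 1` tolerates segments whose
        # stated length exceeds the available encoder frames by up to
        # `subsampling_factor - 1`, which can happen due to rounding in the
        # subsampling layers.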
        dense_fsa_vec = k2.DenseFsaVec(
            ctc_output,
            supervision_segments.cpu(),
            allow_truncate=subsampling_factor - 1,
        )

        ctc_loss = k2.ctc_loss(
            decoding_graph=decoding_graph,
            dense_fsa_vec=dense_fsa_vec,
            output_beam=beam_size,
            reduction=reduction,
            use_double_scores=use_double_scores,
        )

        if self.decoder is not None:
            nnet_output = nnet_output.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
            # Unwrap the decoder if it is wrapped in DistributedDataParallel
            # so that we can call its `forward` directly.
            mmodel = (
                self.decoder.module if hasattr(self.decoder, "module") else self.decoder
            )
            # Note: We need to generate an unsorted version of token_ids.
            # `encode_supervisions()` called above sorts the texts, but
            # encoder_memory and memory_mask are not sorted, so we use the
            # unsorted `supervisions["text"]` to regenerate the token_ids.
            #
            # See https://github.com/k2-fsa/icefall/issues/97
            # for more details.
            unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"])
            mask = encoder_padding_mask(nnet_output.size(0), supervisions)
            mask = mask.to(nnet_output.device) if mask is not None else None
            att_loss = mmodel.forward(
                nnet_output,
                mask,
                token_ids=unsorted_token_ids,
                sos_id=graph_compiler.sos_id,
                eos_id=graph_compiler.eos_id,
            )
        else:
            att_loss = torch.tensor([0])

        return ctc_loss, att_loss, num_frames
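

# A hedged usage sketch (not part of the recipe; the names and values below
# are placeholders). In training, the two returned losses are typically
# interpolated with a weight chosen by the training script:
#
#   ctc_loss, att_loss, num_frames = model(
#       features,                 # (N, T, C) filter-bank features
#       feature_lens,             # (N,) valid frame counts
#       batch["supervisions"],    # dict with "text", frame offsets, etc.
#       graph_compiler,           # BpeCtcTrainingGraphCompiler
#       subsampling_factor=4,
#       use_double_scores=True,
#   )
#   loss = att_scale * att_loss + (1.0 - att_scale) * ctc_loss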