mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
* Replace deprecated pytorch methods - torch.cuda.amp.GradScaler(...) => torch.amp.GradScaler("cuda", ...) - torch.cuda.amp.autocast(...) => torch.amp.autocast("cuda", ...) * Replace `with autocast(...)` with `with autocast("cuda", ...)` Co-authored-by: Li Peng <lipeng@unisound.ai>
260 lines
9.4 KiB
Python
260 lines
9.4 KiB
Python
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
from typing import Optional, Tuple

import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear

from icefall.utils import add_sos
|
|
|
|
|
|
class Transducer(nn.Module):
    """It implements https://arxiv.org/pdf/1211.3711.pdf
    "Sequence Transduction with Recurrent Neural Networks"

    The training loss is the pruned RNN-T loss from k2: a cheap "simple"
    loss on linear projections of the encoder/decoder outputs is used to
    find, per frame, a narrow range of symbol positions, and the full
    joiner is evaluated only inside that range. Optionally, a codebook
    (distillation) loss is computed on an intermediate encoder layer.
    """

    def __init__(
        self,
        encoder: EncoderInterface,
        decoder: nn.Module,
        joiner: nn.Module,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
        num_codebooks: int = 0,
    ):
        """
        Args:
          encoder:
            It is the transcription network in the paper. It is called as
            `encoder(x, x_lens, warmup=warmup)` and must return a pair
            `(layer_results, x_lens)`, where `layer_results[-1]` is used as
            the encoder output and `layer_results[0]` as the intermediate
            output for the codebook loss. (Presumably each entry has shape
            (N, T, encoder_dim) -- TODO confirm against the encoder.)
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
          joiner:
            It has two inputs with shapes: (N, T, encoder_dim) and
            (N, U, decoder_dim).
            Its output shape is (N, T, U, vocab_size). Note that its output
            contains unnormalized probs, i.e., not processed by log-softmax.
            It must also expose `encoder_proj` and `decoder_proj` and accept
            `project_input=False` (used by the pruned-loss path in forward()).
          encoder_dim:
            Output dimension of the encoder; input size of `simple_am_proj`
            and of the codebook predictor.
          decoder_dim:
            Output dimension of the decoder; input size of `simple_lm_proj`.
          joiner_dim:
            Accepted for interface compatibility; not used by this
            constructor.
          vocab_size:
            Output size of the simple am/lm projections.
          num_codebooks:
            Used by distillation loss. When > 0, a `JointCodebookLoss`
            module is created (requires the `multi_quantization` package);
            when 0, no codebook loss is available.

        Raises:
          ValueError: if `num_codebooks > 0` and `multi_quantization` is not
            installed.
        """
        super().__init__()
        assert isinstance(encoder, EncoderInterface), type(encoder)
        assert hasattr(decoder, "blank_id")

        self.encoder = encoder
        self.decoder = decoder
        self.joiner = joiner

        # Linear projections used only by the "simple" (smoothed) loss that
        # determines the pruning ranges.
        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

        if num_codebooks > 0:
            # multi_quantization is only required when distillation is
            # actually enabled, so import it lazily here.
            from icefall import is_module_available

            if not is_module_available("multi_quantization"):
                raise ValueError("Please 'pip install multi_quantization' first.")

            from multi_quantization.prediction import JointCodebookLoss

            self.codebook_loss_net = JointCodebookLoss(
                predictor_channels=encoder_dim,
                num_codebooks=num_codebooks,
                is_joint=False,
            )

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        warmup: float = 1.0,
        reduction: str = "sum",
        codebook_indexes: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for rnnt loss, it means how many symbols (context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part.
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part.
          warmup:
            A value warmup >= 0 that determines which modules are active;
            for values warmup > 1 the model is "fully warmed up" and all
            modules will be active. It is forwarded to the encoder.
          reduction:
            "sum" to sum the losses over all utterances in the batch.
            "none" to return the loss in a 1-D tensor for each utterance
            in the batch.
          codebook_indexes:
            codebook_indexes extracted from a teacher model. Only used in
            training mode; requires the model to have been built with
            `num_codebooks > 0`.
        Returns:
          A tuple `(simple_loss, pruned_loss, codebook_loss)`.
          `simple_loss` and `pruned_loss` are the transducer losses (reduced
          according to `reduction`). `codebook_loss` is the distillation
          loss, or None when not in training mode or when
          `codebook_indexes` is None.

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        assert reduction in ("sum", "none"), reduction
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == x_lens.size(0) == y.dim0

        layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
        encoder_out = layer_results[-1]
        middle_layer_output = layer_results[0]
        if self.training and codebook_indexes is not None:
            assert hasattr(self, "codebook_loss_net")
            # The teacher may run at a different frame rate than the
            # student; align the time axes if they differ.
            if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
                codebook_indexes = self.concat_successive_codebook_indexes(
                    middle_layer_output, codebook_indexes
                )
            codebook_loss = self.codebook_loss_net(
                middle_layer_output, codebook_indexes
            )
        else:
            # when codebook index is not available.
            codebook_loss = None

        assert torch.all(x_lens > 0)

        # Now for the decoder, i.e., the prediction network
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        # boundary rows are [0, 0, num_symbols, num_frames] per utterance.
        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
        boundary[:, 2] = y_lens
        boundary[:, 3] = x_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # The k2 losses are computed in float32 regardless of autocast.
        with torch.amp.autocast("cuda", enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction=reduction,
                return_grad=True,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]

        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch.amp.autocast("cuda", enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction=reduction,
            )

        return (simple_loss, pruned_loss, codebook_loss)

    @staticmethod
    def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes):
        """Align teacher codebook indexes to the student's frame rate.

        The output rate of HuBERT is 50 frames per second, while that of the
        current encoder is 25. The following code handles two issues:

        1. Roughly speaking, to generate another frame of output, HuBERT
           needs two extra input frames, while the current encoder needs
           four. If only three extra frames are provided, HuBERT will
           generate another frame while the current encoder does nothing,
           so the teacher may have surplus trailing frames; they are dropped.
        2. The codebook loss is a frame-wise loss. To enable the 25 fps
           student output to learn from the 50 fps teacher output, each pair
           of successive teacher frames is concatenated together.

        Args:
          middle_layer_output:
            Student tensor whose dim 1 is the target number of frames
            `t_expected`.
          codebook_indexes:
            Teacher tensor of shape (N, T, C) with T >= 2 * t_expected.
        Returns:
          A tensor of shape (N, t_expected, 2 * C).
        """
        t_expected = middle_layer_output.shape[1]
        N, T, C = codebook_indexes.shape

        # Without this check the reshape below fails with an opaque
        # "shape is invalid for input of size" error.
        assert T >= t_expected * 2, (
            f"Teacher provides {T} frames of codebook indexes, but at least "
            f"{t_expected * 2} are needed for {t_expected} student frames."
        )

        # Handling issue 1: drop surplus teacher frames.
        codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
        # Handling issue 2: concatenate each pair of successive frames.
        codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
        assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
        return codebook_indexes
|