
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface
from icefall.utils import add_eos, add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
backward_decoder: nn.Module,
joiner: nn.Module,
backward_joiner: nn.Module,
prune_range: int = 3,
lm_scale: float = 0.0,
am_scale: float = 0.0,
):
"""
Args:
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of shape (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
backward_decoder:
Almost the same as decoder, except that it uses right context,
whereas decoder uses left context.
joiner:
It has two inputs with shapes (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized logits, i.e., it is not processed by log-softmax.
backward_joiner:
The same as joiner, but it is intended for backward_decoder.
prune_range:
The prune range for the RNN-T loss, i.e., how many symbols (of context)
are considered for each frame when computing the pruned loss.
am_scale:
The scale used to smooth the loss with the am (output of the encoder
network) part.
lm_scale:
The scale used to smooth the loss with the lm (output of the prediction
network) part.
Note:
Regarding am_scale & lm_scale, they make the loss function take the
form:
lm_scale * lm_probs + am_scale * am_probs +
(1 - lm_scale - am_scale) * combined_probs
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.backward_decoder = backward_decoder
self.joiner = joiner
self.backward_joiner = backward_joiner
self.prune_range = prune_range
self.lm_scale = lm_scale
self.am_scale = am_scale
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return a tuple of three scalar tensors: the simple (smoothed) RNN-T
loss, the pruned RNN-T loss of the forward (left-context) decoder,
and the pruned RNN-T loss of the backward (right-context) decoder,
each summed over the batch.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
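# Run the encoder (the transcription network); it returns frame-level
# acoustic output and the corresponding output lengths.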
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
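# The differences between consecutive row_splits give the number of
# labels in each utterance.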
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
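# Prepend SOS (re-using the blank ID) to the labels and pad them into a
# dense (N, 1 + U) tensor; this is the input to the forward (left-context)
# prediction network.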
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
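# Each row of `boundary` is [begin_symbol, begin_frame, end_symbol,
# end_frame], as expected by the k2 RNN-T losses; the first two columns
# stay zero and the last two are filled below with the label and frame
# counts.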
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
# Compute the smoothed "simple" RNN-T loss directly from encoder_out and
# decoder_out; the returned px_grad/py_grad are used below to calculate
# the prune ranges.
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
decoder_out,
encoder_out,
y_padded.to(torch.int64),
blank_id,
lm_only_scale=self.lm_scale,
am_only_scale=self.am_scale,
boundary=boundary,
return_grad=True,
)
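# Use the gradients of the simple loss to select, for each frame, a
# window of `prune_range` symbols over which the pruned loss is
# evaluated.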
ranges = k2.get_rnnt_prune_ranges(
px_grad, py_grad, boundary, self.prune_range
)
# Forward (left-context) pruned loss: gather the pruned encoder/decoder
# outputs, join them, and evaluate the pruned RNN-T loss.
am_pruned, lm_pruned = k2.do_rnnt_pruning(
encoder_out, decoder_out, ranges
)
logits = self.joiner(am_pruned, lm_pruned)
pruned_loss = k2.rnnt_loss_pruned(
logits, y_padded.to(torch.int64), ranges, blank_id, boundary
)
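# Build the input for the backward (right-context) prediction network:
# append EOS (re-using the blank ID), then drop the first label and pad on
# the right, so that position u holds the label at u + 1.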
eos_y = add_eos(y, eos_id=blank_id)
eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
eos_y_padded = F.pad(eos_y_padded[:, 1:], pad=(0, 1), value=blank_id)
# Backward (right-context) pruned loss, computed with the same prune ranges.
assert self.backward_decoder is not None
assert self.backward_joiner is not None
backward_decoder_out = self.backward_decoder(eos_y_padded)
backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
encoder_out, backward_decoder_out, ranges
)
backward_logits = self.backward_joiner(
backward_am_pruned, backward_lm_pruned
)
backward_pruned_loss = k2.rnnt_loss_pruned(
backward_logits,
y_padded.to(torch.int64),
ranges,
blank_id,
boundary,
)
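# Negate and sum over the batch so that the returned values are losses to
# be minimized.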
return (
-torch.sum(simple_loss),
-torch.sum(pruned_loss),
-torch.sum(backward_pruned_loss),
)
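if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the original recipe: it wires the
    # Transducer to toy stand-in modules to illustrate the expected shapes.
    # _ToyEncoder, _ToyDecoder and _ToyJoiner are hypothetical; real icefall
    # recipes use a Conformer encoder and the stateless decoder/joiner. Note
    # that the encoder and decoder outputs must have vocabulary-sized last
    # dimensions here, because they are fed directly to k2.rnnt_loss_smoothed.
    class _ToyEncoder(EncoderInterface):
        def __init__(self, num_features: int, vocab_size: int):
            super().__init__()
            self.proj = nn.Linear(num_features, vocab_size)

        def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
            # (N, T, num_features) -> (N, T, vocab_size); lengths are unchanged.
            return self.proj(x), x_lens

    class _ToyDecoder(nn.Module):
        def __init__(self, vocab_size: int, blank_id: int = 0):
            super().__init__()
            self.blank_id = blank_id
            self.embedding = nn.Embedding(vocab_size, vocab_size)

        def forward(self, y: torch.Tensor) -> torch.Tensor:
            # (N, U) -> (N, U, vocab_size)
            return self.embedding(y.to(torch.int64))

    class _ToyJoiner(nn.Module):
        def __init__(self, vocab_size: int):
            super().__init__()
            self.output_linear = nn.Linear(vocab_size, vocab_size)

        def forward(self, am: torch.Tensor, lm: torch.Tensor) -> torch.Tensor:
            # Pruned inputs: both are (N, T, prune_range, vocab_size).
            return self.output_linear(torch.tanh(am + lm))

    vocab_size = 10
    model = Transducer(
        encoder=_ToyEncoder(num_features=8, vocab_size=vocab_size),
        decoder=_ToyDecoder(vocab_size),
        backward_decoder=_ToyDecoder(vocab_size),
        joiner=_ToyJoiner(vocab_size),
        backward_joiner=_ToyJoiner(vocab_size),
        prune_range=3,
    )
    x = torch.randn(2, 20, 8)
    x_lens = torch.tensor([20, 15])
    y = k2.RaggedTensor([[1, 3, 5], [2, 4]])
    simple_loss, pruned_loss, backward_pruned_loss = model(x, x_lens, y)
    print(simple_loss, pruned_loss, backward_pruned_loss)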