Predicting blanks via gradients from the trivial joiner.
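The idea, in short: `k2.rnnt_loss_smoothed(..., return_grad=True)` already returns the gradients `(px_grad, py_grad)` of the simple ("trivial") joiner's loss. Summing `px_grad` over the symbol axis gives, for each encoder frame, a value that can be read as the probability of emitting a non-blank symbol at that frame; clamping it to [0, 1] and subtracting from 1 yields a per-frame soft blank-occupation target. A new `BlankPredictor` module (a single linear layer on top of the encoder output, trained with per-frame `BCEWithLogitsLoss`) learns to predict these targets, and its loss is added to the total training loss with a new `--blank-prediction-scale` option (default 0.1).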

Fangjun Kuang 2022-03-31 20:12:41 +08:00
parent 239a8fa1f2
commit feb526c2a4
4 changed files with 163 additions and 11 deletions

blank_predictor.py

@@ -0,0 +1,65 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from icefall.utils import make_pad_mask
class BlankPredictor(nn.Module):
def __init__(self, encoder_out_dim: int):
"""
Args:
encoder_out_dim:
Output dimension of the encoder network.
"""
super().__init__()
self.linear = nn.Linear(in_features=encoder_out_dim, out_features=1)
self.loss_func = nn.BCEWithLogitsLoss(reduction="none")
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
soft_target: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, encoder_out_dim) from the output of
the encoder network.
x_lens:
A 1-D tensor of shape (N,) containing the number of valid frames
for each element in `x`.
soft_target:
A 2-D tensor of shape (N, T) containing the soft label of each frame
in `x`.
"""
assert x.ndim == 3, x.shape
assert soft_target.ndim == 2, soft_target.shape
assert x.shape[:2] == soft_target.shape[:2], (
x.shape,
soft_target.shape,
)
logits = self.linear(x).squeeze(-1)  # (N, T), one blank logit per frame
mask = make_pad_mask(x_lens)  # (N, T), True at padding positions
loss = self.loss_func(logits, soft_target)  # per-frame BCE, shape (N, T)
loss.masked_fill_(mask, 0)  # padding frames contribute nothing
return loss.sum()
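For orientation, a minimal usage sketch of this module (the dimensions below are made up purely for illustration; in this recipe the module is called from model.py with the encoder output and the soft blank targets derived there):

```python
import torch
from blank_predictor import BlankPredictor

predictor = BlankPredictor(encoder_out_dim=512)  # 512 is an illustrative dimension
encoder_out = torch.rand(2, 100, 512)            # (N, T, encoder_out_dim)
encoder_out_lens = torch.tensor([100, 80])       # number of valid frames per utterance
blank_occupation = torch.rand(2, 100)            # per-frame soft targets in [0, 1]

# Returns a scalar: the per-frame BCE loss summed over all valid frames.
loss = predictor(encoder_out, encoder_out_lens, blank_occupation)
```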

model.py

@@ -15,6 +15,8 @@
# limitations under the License.
from typing import Tuple
import k2
import torch
import torch.nn as nn
@@ -33,6 +35,7 @@ class Transducer(nn.Module):
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
blank_predictor: nn.Module,
):
"""
Args:
@@ -49,6 +52,9 @@ class Transducer(nn.Module):
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
blank_predictor:
The model to predict blanks from the encoder output. See also
`./blank_predictor.py`.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -57,6 +63,7 @@
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.blank_predictor = blank_predictor
def forward(
self,
@@ -66,7 +73,7 @@
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
@@ -87,7 +94,11 @@
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return a tuple containing:
- The loss for the "trivial" joiner
- The loss for the non-linear joiner
- The loss for predicting the blank token
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -101,8 +112,8 @@
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
assert torch.all(encoder_out_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
@@ -126,7 +137,7 @@
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = encoder_out_lens
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out,
@@ -139,6 +150,19 @@
reduction="sum",
return_grad=True,
)
#
# px_grad shape: (B, y_lens.max(), T+1)
# Note: In the paper this quantity is written y'(t, u); it can be read as
# the posterior probability that symbol u is emitted at frame t, so summing
# over u gives the per-frame non-blank occupation.
#
non_blank_occupation = px_grad[:, :, :-1].sum(dim=1)
non_blank_occupation = torch.clamp(non_blank_occupation, min=0, max=1)
blank_occupation = 1 - non_blank_occupation
blank_prediction_loss = self.blank_predictor(
encoder_out,
encoder_out_lens,
blank_occupation,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
@@ -166,4 +190,4 @@
reduction="sum",
)
return (simple_loss, pruned_loss, blank_prediction_loss)
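For intuition, here is a small self-contained sketch of how the soft blank targets above are built; the random tensor below merely stands in for the real `px_grad` returned by `k2.rnnt_loss_smoothed(..., return_grad=True)`:

```python
import torch

N, S, T = 2, 5, 7  # illustrative sizes: batch, y_lens.max(), number of frames

# Stand-in for px_grad, shape (N, S, T + 1). Entry [n, s, t] is (roughly) the
# posterior probability that utterance n emits symbol s at frame t; the extra
# last column corresponds to the terminating position and is dropped below.
px_grad = torch.rand(N, S, T + 1)

# Probability that each frame emits some non-blank symbol, clamped to [0, 1].
non_blank_occupation = px_grad[:, :, :-1].sum(dim=1).clamp(min=0, max=1)  # (N, T)

# Soft target for the blank predictor: how "blank" each frame is.
blank_occupation = 1 - non_blank_occupation  # (N, T)
print(blank_occupation.shape)  # torch.Size([2, 7])
```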

test_blank_predictor.py

@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless_2/test_blank_predictor.py
"""
import torch
from blank_predictor import BlankPredictor
def test_blank_predictor():
dim = 10
predictor = BlankPredictor(encoder_out_dim=dim)
x = torch.rand(4, 3, dim)
x_lens = torch.tensor([1, 3, 2, 3], dtype=torch.int32)
y = torch.rand(4, 3)
loss = predictor(x, x_lens, y)
print(loss)
def main():
test_blank_predictor()
if __name__ == "__main__":
main()
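A possible extra check, not part of this commit: since the loss masks out padding frames via `make_pad_mask(x_lens)`, modifying `x` beyond each utterance's valid length must not change the result. A hypothetical test along those lines:

```python
def test_blank_predictor_masking():
    dim = 10
    predictor = BlankPredictor(encoder_out_dim=dim)
    x = torch.rand(4, 3, dim)
    x_lens = torch.tensor([1, 3, 2, 3], dtype=torch.int32)
    y = torch.rand(4, 3)
    loss1 = predictor(x, x_lens, y)

    x2 = x.clone()
    x2[0, 1:] = 0.0  # frames past x_lens[0] == 1 are padding
    x2[2, 2:] = 0.0  # frames past x_lens[2] == 2 are padding
    loss2 = predictor(x2, x_lens, y)

    # Padding frames are masked out, so the loss must be unchanged.
    assert torch.allclose(loss1, loss2)
```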

train.py

@@ -21,11 +21,11 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless-2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless-2/exp \
--full-libri 1 \
--max-duration 300
"""
@@ -44,6 +44,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from blank_predictor import BlankPredictor
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@@ -128,7 +129,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless-2/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -191,6 +192,13 @@ def get_parser():
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--blank-prediction-scale",
type=float,
default=0.1,
help="Scale to use for the blank prediction loss",
)
parser.add_argument(
"--seed",
type=int,
@@ -333,15 +341,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner
def get_blank_prediction_model(params: AttributeDict) -> nn.Module:
blank_predictor = BlankPredictor(encoder_out_dim=params.vocab_size)
return blank_predictor
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
blank_predictor = get_blank_prediction_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
blank_predictor=blank_predictor,
)
return model
@@ -484,7 +499,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, blank_prediction_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@@ -492,7 +507,11 @@
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss
+ params.blank_prediction_scale * blank_prediction_loss
)
assert loss.requires_grad == is_training
@@ -507,6 +526,7 @@
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
info["blank_prediction_loss"] = blank_prediction_loss.detach().cpu().item()
return loss, info