Predicting blanks via gradients from the trivial joiner.

This commit is contained in:
Fangjun Kuang 2022-03-31 20:12:41 +08:00
parent 239a8fa1f2
commit feb526c2a4
4 changed files with 163 additions and 11 deletions

View File

@ -0,0 +1,65 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from icefall.utils import make_pad_mask
class BlankPredictor(nn.Module):
def __init__(self, encoder_out_dim: int):
"""
Args:
Output dimension of the encoder network.
"""
super().__init__()
self.linear = nn.Linear(in_features=encoder_out_dim, out_features=1)
self.loss_func = nn.BCEWithLogitsLoss(reduction="none")
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
soft_target: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, encoder_out_dim) from the output of
the encoder network.
x_lens:
A 1-D tensor of shape (N,) containing the number of valid frames
for each element in `x`.
soft_target:
A 2-D tensor of shape (N, T) containing the soft label of each frame
in `x`.
"""
assert x.ndim == 3, x.shape
assert soft_target.ndim == 2, soft_target.shape
assert x.shape[:2] == soft_target.shape[:2], (
x.shape,
soft_target.shape,
)
logits = self.linear(x).squeeze(-1)
mask = make_pad_mask(x_lens)
loss = self.loss_func(logits, soft_target)
loss.masked_fill_(mask, 0)
return loss.sum()

View File

@ -15,6 +15,8 @@
# limitations under the License.
from typing import Tuple
import k2
import torch
import torch.nn as nn
@ -33,6 +35,7 @@ class Transducer(nn.Module):
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
blank_predictor: nn.Module,
):
"""
Args:
@ -49,6 +52,9 @@ class Transducer(nn.Module):
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
blank_predictor:
The model to predict blanks from the encoder output. See also
`./blank_predictor.py`.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
@ -57,6 +63,7 @@ class Transducer(nn.Module):
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
self.blank_predictor = blank_predictor
def forward(
self,
@ -66,7 +73,7 @@ class Transducer(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
@ -87,7 +94,11 @@ class Transducer(nn.Module):
The scale to smooth the loss with lm (output of predictor network)
part
Returns:
Return the transducer loss.
Return a tuple containing:
- The loss for the "trivial" joiner
- The loss for the non-linear joiner
- The loss for predicting the blank token
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
@ -101,8 +112,8 @@ class Transducer(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
assert torch.all(encoder_out_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
@ -126,7 +137,7 @@ class Transducer(nn.Module):
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
boundary[:, 3] = encoder_out_lens
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out,
@ -139,6 +150,19 @@ class Transducer(nn.Module):
reduction="sum",
return_grad=True,
)
#
# px_grad shape: (B, y_lens.max(), T+1)
# Note: In the paper, we use y'(t, u)
#
non_blank_occuptation = px_grad[:, :, :-1].sum(dim=1)
non_blank_occuptation = torch.clamp(non_blank_occuptation, min=0, max=1)
blank_occupation = 1 - non_blank_occuptation
blank_prediction_loss = self.blank_predictor(
encoder_out,
encoder_out_lens,
blank_occupation,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
@ -166,4 +190,4 @@ class Transducer(nn.Module):
reduction="sum",
)
return (simple_loss, pruned_loss)
return (simple_loss, pruned_loss, blank_prediction_loss)

View File

@ -0,0 +1,43 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./pruned_transducer_stateless_2/test_blank_predictor.py
"""
import torch
from blank_predictor import BlankPredictor
def test_blank_predictor():
dim = 10
predictor = BlankPredictor(encoder_out_dim=dim)
x = torch.rand(4, 3, dim)
x_lens = torch.tensor([1, 3, 2, 3], dtype=torch.int32)
y = torch.rand(4, 3)
loss = predictor(x, x_lens, y)
print(loss)
def main():
test_blank_predictor()
if __name__ == "__main__":
main()

View File

@ -21,11 +21,11 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless/train.py \
./pruned_transducer_stateless-2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--exp-dir pruned_transducer_stateless-2/exp \
--full-libri 1 \
--max-duration 300
"""
@ -44,6 +44,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from blank_predictor import BlankPredictor
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
@ -128,7 +129,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
default="pruned_transducer_stateless-2/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@ -191,6 +192,13 @@ def get_parser():
"with this parameter before adding to the final loss.",
)
parser.add_argument(
"--blank-prediction-scale",
type=float,
default=0.1,
help="Scale to use for the blank prediction loss",
)
parser.add_argument(
"--seed",
type=int,
@ -333,15 +341,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
return joiner
def get_blank_prediction_model(params: AttributeDict) -> nn.Module:
blank_predictor = BlankPredictor(encoder_out_dim=params.vocab_size)
return blank_predictor
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
blank_predictor = get_blank_prediction_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
blank_predictor=blank_predictor,
)
return model
@ -484,7 +499,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
simple_loss, pruned_loss, blank_prediction_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -492,7 +507,11 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
loss = params.simple_loss_scale * simple_loss + pruned_loss
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss
+ params.blank_prediction_scale * blank_prediction_loss
)
assert loss.requires_grad == is_training
@ -507,6 +526,7 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
info["blank_prediction_loss"] = blank_prediction_loss.detach().cpu().item()
return loss, info