mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Predicting blanks via gradients from the trivial joiner.
This commit is contained in:
parent
239a8fa1f2
commit
feb526c2a4
@ -0,0 +1,65 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
|
||||
|
||||
class BlankPredictor(nn.Module):
|
||||
def __init__(self, encoder_out_dim: int):
|
||||
"""
|
||||
Args:
|
||||
Output dimension of the encoder network.
|
||||
"""
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(in_features=encoder_out_dim, out_features=1)
|
||||
|
||||
self.loss_func = nn.BCEWithLogitsLoss(reduction="none")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
soft_target: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, encoder_out_dim) from the output of
|
||||
the encoder network.
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,) containing the number of valid frames
|
||||
for each element in `x`.
|
||||
soft_target:
|
||||
A 2-D tensor of shape (N, T) containing the soft label of each frame
|
||||
in `x`.
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert soft_target.ndim == 2, soft_target.shape
|
||||
|
||||
assert x.shape[:2] == soft_target.shape[:2], (
|
||||
x.shape,
|
||||
soft_target.shape,
|
||||
)
|
||||
logits = self.linear(x).squeeze(-1)
|
||||
mask = make_pad_mask(x_lens)
|
||||
|
||||
loss = self.loss_func(logits, soft_target)
|
||||
loss.masked_fill_(mask, 0)
|
||||
|
||||
return loss.sum()
|
@ -15,6 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -33,6 +35,7 @@ class Transducer(nn.Module):
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
blank_predictor: nn.Module,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -49,6 +52,9 @@ class Transducer(nn.Module):
|
||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||
output shape is (N, T, U, C). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
blank_predictor:
|
||||
The model to predict blanks from the encoder output. See also
|
||||
`./blank_predictor.py`.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
@ -57,6 +63,7 @@ class Transducer(nn.Module):
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
self.blank_predictor = blank_predictor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -66,7 +73,7 @@ class Transducer(nn.Module):
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
@ -87,7 +94,11 @@ class Transducer(nn.Module):
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
Return a tuple containing:
|
||||
|
||||
- The loss for the "trivial" joiner
|
||||
- The loss for the non-linear joiner
|
||||
- The loss for predicting the blank token
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
@ -101,8 +112,8 @@ class Transducer(nn.Module):
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(encoder_out_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
@ -126,7 +137,7 @@ class Transducer(nn.Module):
|
||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
boundary[:, 3] = encoder_out_lens
|
||||
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=decoder_out,
|
||||
@ -139,6 +150,19 @@ class Transducer(nn.Module):
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
#
|
||||
# px_grad shape: (B, y_lens.max(), T+1)
|
||||
# Note: In the paper, we use y'(t, u)
|
||||
#
|
||||
non_blank_occuptation = px_grad[:, :, :-1].sum(dim=1)
|
||||
non_blank_occuptation = torch.clamp(non_blank_occuptation, min=0, max=1)
|
||||
blank_occupation = 1 - non_blank_occuptation
|
||||
|
||||
blank_prediction_loss = self.blank_predictor(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
blank_occupation,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
@ -166,4 +190,4 @@ class Transducer(nn.Module):
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
return (simple_loss, pruned_loss, blank_prediction_loss)
|
||||
|
43
egs/librispeech/ASR/pruned_transducer_stateless-2/test_blank_predictor.py
Executable file
43
egs/librispeech/ASR/pruned_transducer_stateless-2/test_blank_predictor.py
Executable file
@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless_2/test_blank_predictor.py
|
||||
"""
|
||||
import torch
|
||||
from blank_predictor import BlankPredictor
|
||||
|
||||
|
||||
def test_blank_predictor():
|
||||
dim = 10
|
||||
predictor = BlankPredictor(encoder_out_dim=dim)
|
||||
x = torch.rand(4, 3, dim)
|
||||
x_lens = torch.tensor([1, 3, 2, 3], dtype=torch.int32)
|
||||
y = torch.rand(4, 3)
|
||||
loss = predictor(x, x_lens, y)
|
||||
print(loss)
|
||||
|
||||
|
||||
def main():
|
||||
test_blank_predictor()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -21,11 +21,11 @@ Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless/train.py \
|
||||
./pruned_transducer_stateless-2/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--exp-dir pruned_transducer_stateless-2/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
"""
|
||||
@ -44,6 +44,7 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from blank_predictor import BlankPredictor
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -128,7 +129,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless/exp",
|
||||
default="pruned_transducer_stateless-2/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@ -191,6 +192,13 @@ def get_parser():
|
||||
"with this parameter before adding to the final loss.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-prediction-scale",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="Scale to use for the blank prediction loss",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
@ -333,15 +341,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
return joiner
|
||||
|
||||
|
||||
def get_blank_prediction_model(params: AttributeDict) -> nn.Module:
|
||||
blank_predictor = BlankPredictor(encoder_out_dim=params.vocab_size)
|
||||
return blank_predictor
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
blank_predictor = get_blank_prediction_model(params)
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
blank_predictor=blank_predictor,
|
||||
)
|
||||
return model
|
||||
|
||||
@ -484,7 +499,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss = model(
|
||||
simple_loss, pruned_loss, blank_prediction_loss = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -492,7 +507,11 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
loss = params.simple_loss_scale * simple_loss + pruned_loss
|
||||
loss = (
|
||||
params.simple_loss_scale * simple_loss
|
||||
+ pruned_loss
|
||||
+ params.blank_prediction_scale * blank_prediction_loss
|
||||
)
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
@ -507,6 +526,7 @@ def compute_loss(
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
info["blank_prediction_loss"] = blank_prediction_loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user