mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Predicting blanks via gradients from the trivial joiner.
This commit is contained in:
parent
239a8fa1f2
commit
feb526c2a4
@ -0,0 +1,65 @@
|
|||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
|
class BlankPredictor(nn.Module):
|
||||||
|
def __init__(self, encoder_out_dim: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
Output dimension of the encoder network.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(in_features=encoder_out_dim, out_features=1)
|
||||||
|
|
||||||
|
self.loss_func = nn.BCEWithLogitsLoss(reduction="none")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
soft_target: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, encoder_out_dim) from the output of
|
||||||
|
the encoder network.
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,) containing the number of valid frames
|
||||||
|
for each element in `x`.
|
||||||
|
soft_target:
|
||||||
|
A 2-D tensor of shape (N, T) containing the soft label of each frame
|
||||||
|
in `x`.
|
||||||
|
"""
|
||||||
|
assert x.ndim == 3, x.shape
|
||||||
|
assert soft_target.ndim == 2, soft_target.shape
|
||||||
|
|
||||||
|
assert x.shape[:2] == soft_target.shape[:2], (
|
||||||
|
x.shape,
|
||||||
|
soft_target.shape,
|
||||||
|
)
|
||||||
|
logits = self.linear(x).squeeze(-1)
|
||||||
|
mask = make_pad_mask(x_lens)
|
||||||
|
|
||||||
|
loss = self.loss_func(logits, soft_target)
|
||||||
|
loss.masked_fill_(mask, 0)
|
||||||
|
|
||||||
|
return loss.sum()
|
@ -15,6 +15,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -33,6 +35,7 @@ class Transducer(nn.Module):
|
|||||||
encoder: EncoderInterface,
|
encoder: EncoderInterface,
|
||||||
decoder: nn.Module,
|
decoder: nn.Module,
|
||||||
joiner: nn.Module,
|
joiner: nn.Module,
|
||||||
|
blank_predictor: nn.Module,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -49,6 +52,9 @@ class Transducer(nn.Module):
|
|||||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||||
output shape is (N, T, U, C). Note that its output contains
|
output shape is (N, T, U, C). Note that its output contains
|
||||||
unnormalized probs, i.e., not processed by log-softmax.
|
unnormalized probs, i.e., not processed by log-softmax.
|
||||||
|
blank_predictor:
|
||||||
|
The model to predict blanks from the encoder output. See also
|
||||||
|
`./blank_predictor.py`.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
@ -57,6 +63,7 @@ class Transducer(nn.Module):
|
|||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.joiner = joiner
|
self.joiner = joiner
|
||||||
|
self.blank_predictor = blank_predictor
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@ -66,7 +73,7 @@ class Transducer(nn.Module):
|
|||||||
prune_range: int = 5,
|
prune_range: int = 5,
|
||||||
am_scale: float = 0.0,
|
am_scale: float = 0.0,
|
||||||
lm_scale: float = 0.0,
|
lm_scale: float = 0.0,
|
||||||
) -> torch.Tensor:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x:
|
x:
|
||||||
@ -87,7 +94,11 @@ class Transducer(nn.Module):
|
|||||||
The scale to smooth the loss with lm (output of predictor network)
|
The scale to smooth the loss with lm (output of predictor network)
|
||||||
part
|
part
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return a tuple containing:
|
||||||
|
|
||||||
|
- The loss for the "trivial" joiner
|
||||||
|
- The loss for the non-linear joiner
|
||||||
|
- The loss for predicting the blank token
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||||
@ -101,8 +112,8 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||||
|
|
||||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens)
|
||||||
assert torch.all(x_lens > 0)
|
assert torch.all(encoder_out_lens > 0)
|
||||||
|
|
||||||
# Now for the decoder, i.e., the prediction network
|
# Now for the decoder, i.e., the prediction network
|
||||||
row_splits = y.shape.row_splits(1)
|
row_splits = y.shape.row_splits(1)
|
||||||
@ -126,7 +137,7 @@ class Transducer(nn.Module):
|
|||||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||||
)
|
)
|
||||||
boundary[:, 2] = y_lens
|
boundary[:, 2] = y_lens
|
||||||
boundary[:, 3] = x_lens
|
boundary[:, 3] = encoder_out_lens
|
||||||
|
|
||||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
lm=decoder_out,
|
lm=decoder_out,
|
||||||
@ -139,6 +150,19 @@ class Transducer(nn.Module):
|
|||||||
reduction="sum",
|
reduction="sum",
|
||||||
return_grad=True,
|
return_grad=True,
|
||||||
)
|
)
|
||||||
|
#
|
||||||
|
# px_grad shape: (B, y_lens.max(), T+1)
|
||||||
|
# Note: In the paper, we use y'(t, u)
|
||||||
|
#
|
||||||
|
non_blank_occuptation = px_grad[:, :, :-1].sum(dim=1)
|
||||||
|
non_blank_occuptation = torch.clamp(non_blank_occuptation, min=0, max=1)
|
||||||
|
blank_occupation = 1 - non_blank_occuptation
|
||||||
|
|
||||||
|
blank_prediction_loss = self.blank_predictor(
|
||||||
|
encoder_out,
|
||||||
|
encoder_out_lens,
|
||||||
|
blank_occupation,
|
||||||
|
)
|
||||||
|
|
||||||
# ranges : [B, T, prune_range]
|
# ranges : [B, T, prune_range]
|
||||||
ranges = k2.get_rnnt_prune_ranges(
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
@ -166,4 +190,4 @@ class Transducer(nn.Module):
|
|||||||
reduction="sum",
|
reduction="sum",
|
||||||
)
|
)
|
||||||
|
|
||||||
return (simple_loss, pruned_loss)
|
return (simple_loss, pruned_loss, blank_prediction_loss)
|
||||||
|
43
egs/librispeech/ASR/pruned_transducer_stateless-2/test_blank_predictor.py
Executable file
43
egs/librispeech/ASR/pruned_transducer_stateless-2/test_blank_predictor.py
Executable file
@ -0,0 +1,43 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
To run this file, do:
|
||||||
|
|
||||||
|
cd icefall/egs/librispeech/ASR
|
||||||
|
python ./pruned_transducer_stateless_2/test_blank_predictor.py
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from blank_predictor import BlankPredictor
|
||||||
|
|
||||||
|
|
||||||
|
def test_blank_predictor():
|
||||||
|
dim = 10
|
||||||
|
predictor = BlankPredictor(encoder_out_dim=dim)
|
||||||
|
x = torch.rand(4, 3, dim)
|
||||||
|
x_lens = torch.tensor([1, 3, 2, 3], dtype=torch.int32)
|
||||||
|
y = torch.rand(4, 3)
|
||||||
|
loss = predictor(x, x_lens, y)
|
||||||
|
print(loss)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
test_blank_predictor()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -21,11 +21,11 @@ Usage:
|
|||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
./pruned_transducer_stateless/train.py \
|
./pruned_transducer_stateless-2/train.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--exp-dir pruned_transducer_stateless/exp \
|
--exp-dir pruned_transducer_stateless-2/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 300
|
--max-duration 300
|
||||||
"""
|
"""
|
||||||
@ -44,6 +44,7 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from blank_predictor import BlankPredictor
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
@ -128,7 +129,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="pruned_transducer_stateless/exp",
|
default="pruned_transducer_stateless-2/exp",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
@ -191,6 +192,13 @@ def get_parser():
|
|||||||
"with this parameter before adding to the final loss.",
|
"with this parameter before adding to the final loss.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--blank-prediction-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="Scale to use for the blank prediction loss",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
@ -333,15 +341,22 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
|
def get_blank_prediction_model(params: AttributeDict) -> nn.Module:
|
||||||
|
blank_predictor = BlankPredictor(encoder_out_dim=params.vocab_size)
|
||||||
|
return blank_predictor
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
blank_predictor = get_blank_prediction_model(params)
|
||||||
|
|
||||||
model = Transducer(
|
model = Transducer(
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
joiner=joiner,
|
joiner=joiner,
|
||||||
|
blank_predictor=blank_predictor,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -484,7 +499,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss = model(
|
simple_loss, pruned_loss, blank_prediction_loss = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -492,7 +507,11 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
)
|
)
|
||||||
loss = params.simple_loss_scale * simple_loss + pruned_loss
|
loss = (
|
||||||
|
params.simple_loss_scale * simple_loss
|
||||||
|
+ pruned_loss
|
||||||
|
+ params.blank_prediction_scale * blank_prediction_loss
|
||||||
|
)
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
@ -507,6 +526,7 @@ def compute_loss(
|
|||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
|
info["blank_prediction_loss"] = blank_prediction_loss.detach().cpu().item()
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user