add option for choosing rnnt type

Desh Raj 2023-06-26 11:47:23 -04:00
parent 4325bb20b9
commit b24be1aa9d
3 changed files with 280 additions and 3 deletions


@@ -82,6 +82,52 @@ avg=22
A pre-trained model and decoding logs can be found at <https://huggingface.co/desh2608/icefall-asr-tedlium3-zipformer>

#### 2023-06-26 (transducer topology)

**Modified transducer**

```
./zipformer/train.py \
--use-fp16 true \
--world-size 4 \
--num-epochs 50 \
--start-epoch 0 \
--exp-dir zipformer/exp \
--max-duration 1000 \
--rnnt-type modified
```

| decoding method                     | dev  | test | comment                                   |
|-------------------------------------|------|------|-------------------------------------------|
| greedy search | 6.32 | 5.83 | --epoch 50, --avg 22, --max-duration 500 |
| beam search (beam size 4) | 6.56 | 5.95 | --epoch 50, --avg 22, --max-duration 500 |
| modified beam search (beam size 4) | 6.16 | 5.79 | --epoch 50, --avg 22, --max-duration 500 |
| fast beam search (set as default) | 6.30 | 5.89 | --epoch 50, --avg 22, --max-duration 500 |

A pre-trained model and decoding logs can be found at .

**Constrained transducer**

```
./zipformer/train.py \
--use-fp16 true \
--world-size 4 \
--num-epochs 50 \
--start-epoch 0 \
--exp-dir zipformer/exp \
--max-duration 1000 \
--rnnt-type constrained
```

| decoding method                     | dev  | test | comment                                   |
|-------------------------------------|------|------|-------------------------------------------|
| greedy search | 6.58 | 6.20 | --epoch 50, --avg 22, --max-duration 500 |
| beam search (beam size 4) | 6.34 | 5.92 | --epoch 50, --avg 22, --max-duration 500 |
| modified beam search (beam size 4) | 6.38 | 5.84 | --epoch 50, --avg 22, --max-duration 500 |
| fast beam search (set as default) | 6.68 | 6.29 | --epoch 50, --avg 22, --max-duration 500 |

A pre-trained model and decoding logs can be found at .
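
The WERs above were produced with the settings listed in the comment column. The command below is a sketch of a typical decoding run for one row; it assumes the standard icefall zipformer `decode.py` flags (`--decoding-method`, `--beam-size`), which should be verified against this recipe:

```
./zipformer/decode.py \
  --epoch 50 \
  --avg 22 \
  --exp-dir zipformer/exp \
  --max-duration 500 \
  --decoding-method modified_beam_search \
  --beam-size 4
```

For the other rows, set `--decoding-method` to `greedy_search`, `beam_search`, or `fast_beam_search`. The same command applies to both the modified and constrained models; only `--exp-dir` differs.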

### TedLium3 BPE training results (Conformer-CTC 2)

#### [conformer_ctc2](./conformer_ctc2)


@@ -1 +0,0 @@
-../../../librispeech/ASR/zipformer/model.py


@@ -0,0 +1,223 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface

from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear


class Transducer(nn.Module):
    """It implements https://arxiv.org/pdf/1211.3711.pdf
    "Sequence Transduction with Recurrent Neural Networks"
    """

    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: EncoderInterface,
        decoder: nn.Module,
        joiner: nn.Module,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder_embed:
            It is a Convolutional 2D subsampling module. It converts
            an input of shape (N, T, idim) to an output of shape
            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
          encoder:
            It is the transcription network in the paper. It accepts
            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
          joiner:
            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
            Its output shape is (N, T, U, vocab_size). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
        """
        super().__init__()

        assert isinstance(encoder, EncoderInterface), type(encoder)
        assert hasattr(decoder, "blank_id")

        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.decoder = decoder
        self.joiner = joiner

        self.simple_am_proj = ScaledLinear(
            encoder_dim,
            vocab_size,
            initial_scale=0.25,
        )
        self.simple_lm_proj = ScaledLinear(
            decoder_dim,
            vocab_size,
            initial_scale=0.25,
        )

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        rnnt_type: str = "regular",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for rnnt loss; it means how many symbols (context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part.
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part.
          rnnt_type:
            The type of label topology to use for the transducer loss. One of
            "regular", "modified", or "constrained". With "regular", only a
            blank emission advances to the next frame; with "modified", every
            emission advances to the next frame; "constrained" is like
            "regular" except that emitting a non-blank symbol forces a blank
            to be emitted on the same frame from the following context.
        Returns:
          Return the transducer losses, in the form of
          (simple_loss, pruned_loss).
        Note:
          Regarding am_scale & lm_scale, it will make the loss-function one of
          the form:
            lm_scale * lm_probs + am_scale * am_probs +
            (1-lm_scale-am_scale) * combined_probs
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes
        assert x.size(0) == x_lens.size(0) == y.dim0

        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
        x, x_lens = self.encoder_embed(x, x_lens)
        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        assert torch.all(x_lens > 0)

        # Now for the decoder, i.e., the prediction network
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = x_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # if self.training and random.random() < 0.25:
        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
        # if self.training and random.random() < 0.25:
        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
                rnnt_type=rnnt_type,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]

        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
                rnnt_type=rnnt_type,
            )

        return (simple_loss, pruned_loss)
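
For orientation, here is how the two returned losses are typically combined in the training loop. This is a sketch based on the `compute_loss` change shown below; the actual recipe additionally applies warmup-dependent scaling to `simple_loss`, so treat the final line as illustrative:

```
# Sketch (Python): combining the losses returned by Transducer.forward().
simple_loss, pruned_loss = model(
    x=feature,
    x_lens=feature_lens,
    y=y,
    prune_range=params.prune_range,
    am_scale=params.am_scale,
    lm_scale=params.lm_scale,
    rnnt_type=params.rnnt_type,  # "regular", "modified", or "constrained"
)
# simple_loss regularizes early training; pruned_loss is the main objective.
loss = params.simple_loss_scale * simple_loss + pruned_loss
```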


@@ -68,7 +68,7 @@ from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
-from model import AsrModel
+from model import Transducer
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
@@ -354,6 +354,13 @@ def get_parser():
         "we are using to compute the loss",
     )
 
+    parser.add_argument(
+        "--rnnt-type",
+        type=str,
+        default="regular",
+        choices=["regular", "modified", "constrained"],
+    )
+
     parser.add_argument(
         "--lm-scale",
         type=float,
@@ -585,13 +592,14 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
-    model = AsrModel(
+    model = Transducer(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
         encoder_dim=int(max(params.encoder_dim.split(","))),
         decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
     )
     return model
@@ -762,6 +770,7 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            rnnt_type=params.rnnt_type,
         )
 
         s = params.simple_loss_scale