add CTC loss option in zipformer recipe

yaozengwei 2023-06-02 10:40:10 +08:00
parent 7a604057f9
commit 96233e35a1
7 changed files with 392 additions and 47 deletions


@@ -116,7 +116,7 @@ from beam_search import (
    greedy_search_batch,
    modified_beam_search,
)
-from train import add_model_arguments, get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_model

from icefall.checkpoint import (
    average_checkpoints,
@@ -694,7 +694,7 @@ def main():
    logging.info(params)

    logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:


@@ -161,7 +161,7 @@ from typing import List, Tuple
import sentencepiece as spm
import torch
from torch import Tensor, nn
-from train import add_model_arguments, get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_model

from icefall.checkpoint import (
    average_checkpoints,
@@ -408,7 +408,7 @@ def main():
    logging.info(params)

    logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:


@@ -44,7 +44,7 @@ import sentencepiece as spm
import torch
from asr_datamodule import LibriSpeechAsrDataModule
-from train import add_model_arguments, get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_model

from icefall.checkpoint import (
    average_checkpoints_with_averaged_model,
@@ -140,7 +140,7 @@ def main():
    params.vocab_size = sp.get_piece_size()

    print("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)

    if params.iter > 0:
        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[


@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from typing import Optional, Tuple

import k2
import torch
@@ -37,7 +38,7 @@ class Transducer(nn.Module):
        joiner: nn.Module,
        encoder_dim: int,
        decoder_dim: int,
-        joiner_dim: int,
+        joiner_dim: int,  # TODO: remove it
        vocab_size: int,
    ):
        """
@@ -69,16 +70,8 @@ class Transducer(nn.Module):
        self.decoder = decoder
        self.joiner = joiner

-        self.simple_am_proj = ScaledLinear(
-            encoder_dim,
-            vocab_size,
-            initial_scale=0.25,
-        )
-        self.simple_lm_proj = ScaledLinear(
-            decoder_dim,
-            vocab_size,
-            initial_scale=0.25,
-        )
+        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_scale=0.25)
+        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size, initial_scale=0.25)

    def forward(
        self,
@@ -215,3 +208,310 @@ class Transducer(nn.Module):
            )

        return (simple_loss, pruned_loss)


class AsrModel(nn.Module):
    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: EncoderInterface,
        decoder: Optional[nn.Module] = None,
        joiner: Optional[nn.Module] = None,
        encoder_dim: int = 384,
        decoder_dim: int = 512,
        vocab_size: int = 500,
        use_transducer: bool = True,
        use_ctc: bool = False,
    ):
        """A joint CTC & Transducer model.

        Args:
          encoder_embed:
            It is a Convolutional 2D subsampling module. It converts
            an input of shape (N, T, idim) to an output of shape
            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
          encoder:
            It is the transcription network in the paper. It accepts
            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
            `logit_lens` of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
            It is used when use_transducer is True.
          joiner:
            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
            Its output shape is (N, T, U, vocab_size). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
            It is used when use_transducer is True.
          use_transducer:
            Whether to use the Transducer head. Default: True.
          use_ctc:
            Whether to use the CTC head. Default: False.
        """
        super().__init__()

        assert (
            use_transducer or use_ctc
        ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"

        assert isinstance(encoder, EncoderInterface), type(encoder)

        self.encoder_embed = encoder_embed
        self.encoder = encoder

        self.use_transducer = use_transducer
        if use_transducer:
            # Modules for Transducer head
            assert decoder is not None
            assert hasattr(decoder, "blank_id")
            assert joiner is not None

            self.decoder = decoder
            self.joiner = joiner

            self.simple_am_proj = ScaledLinear(
                encoder_dim, vocab_size, initial_scale=0.25
            )
            self.simple_lm_proj = ScaledLinear(
                decoder_dim, vocab_size, initial_scale=0.25
            )

        self.use_ctc = use_ctc
        if use_ctc:
            # Modules for CTC head
            self.ctc_output = nn.Sequential(
                nn.Dropout(p=0.1),
                nn.Linear(encoder_dim, vocab_size),
                nn.LogSoftmax(dim=-1),
            )

    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths),). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
          target_lengths:
            Target lengths of each utterance, of shape (N,).
        """
        # Compute CTC log-prob
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets,
            input_lengths=encoder_out_lens,
            target_lengths=target_lengths,
            reduction="sum",
        )
        return ctc_loss

    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Transducer loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of labels
            of each utterance.
          prune_range:
            The prune range for rnnt loss, i.e., how many symbols (context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network) part.
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network) part.
        """
        # Now for the decoder, i.e., the prediction network
        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # if self.training and random.random() < 0.25:
        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
        # if self.training and random.random() < 0.25:
        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)

        with torch.cuda.amp.autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
            )

        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # logits : [B, T, prune_range, vocab_size]
        # project_input=False since we applied the decoder's input projections
        # prior to do_rnnt_pruning (this is an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch.cuda.amp.autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
            )

        return simple_loss, pruned_loss

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for rnnt loss, i.e., how many symbols (context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network) part.
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network) part.
        Returns:
          Return the transducer losses and CTC loss,
          in the form of (simple_loss, pruned_loss, ctc_loss).

        Note:
          Regarding am_scale & lm_scale, it will make the loss-function one of
          the form:
            lm_scale * lm_probs + am_scale * am_probs +
            (1-lm_scale-am_scale) * combined_probs
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == x_lens.size(0) == y.dim0

        # Compute encoder outputs
        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
        x, x_lens = self.encoder_embed(x, x_lens)
        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        assert torch.all(x_lens > 0)

        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        if self.use_transducer:
            # Compute transducer loss
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=x_lens,
                y=y,
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
        else:
            simple_loss = 0
            pruned_loss = 0

        if self.use_ctc:
            # Compute CTC loss
            targets = [t for tokens in y.tolist() for t in tokens]
            # of shape (sum(y_lens),)
            targets = torch.tensor(targets, device=x.device, dtype=torch.int64)
            ctc_loss = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=x_lens,
                targets=targets,
                target_lengths=y_lens,
            )
        else:
            ctc_loss = 0

        return simple_loss, pruned_loss, ctc_loss
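
For reference, the CTC head added above is just a Dropout -> Linear -> LogSoftmax projection followed by PyTorch's built-in CTC loss. The standalone sketch below mirrors what forward_ctc computes; it is only an illustration, and the batch size, lengths, and random tensors are made-up values rather than anything produced by the recipe.

import torch
import torch.nn as nn

# Illustrative sizes; in the recipe they come from the Zipformer encoder and the BPE vocab.
N, T, C, vocab_size = 2, 50, 384, 500

ctc_head = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Linear(C, vocab_size),
    nn.LogSoftmax(dim=-1),
)

encoder_out = torch.randn(N, T, C)          # encoder output, (N, T, C)
encoder_out_lens = torch.tensor([50, 45])   # number of valid frames per utterance
target_lengths = torch.tensor([10, 8])      # number of tokens per utterance
# Un-padded targets concatenated into one 1-D tensor of shape (sum(target_lengths),)
targets = torch.randint(1, vocab_size, (int(target_lengths.sum().item()),))

log_probs = ctc_head(encoder_out)           # (N, T, vocab_size), already log-softmaxed
ctc_loss = nn.functional.ctc_loss(
    log_probs=log_probs.permute(1, 0, 2),   # ctc_loss expects (T, N, C)
    targets=targets,
    input_lengths=encoder_out_lens,
    target_lengths=target_lengths,
    reduction="sum",                        # summed, like the other losses in the recipe
)
print(ctc_loss)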


@@ -122,7 +122,7 @@ from beam_search import (
)
from icefall.utils import make_pad_mask
from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_model


def get_parser():
@@ -284,7 +284,7 @@ def main():
    ), "left_context_frames should be one value in decoding."

    logging.info("Creating model")
-    model = get_transducer_model(params)
+    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")


@@ -51,7 +51,7 @@ from streaming_beam_search import (
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_model

from icefall.checkpoint import (
    average_checkpoints,
@@ -756,7 +756,7 @@ def main():
    logging.info(params)

    logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)

    if not params.use_averaged_model:
        if params.iter > 0:


@@ -71,6 +71,7 @@ from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
+from model import AsrModel
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
@@ -242,6 +243,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        "chunk left-context frames will be chosen randomly from this list; else not relevant."
    )

+    parser.add_argument(
+        "--use-transducer",
+        type=str2bool,
+        default=True,
+        help="If True, use Transducer head.",
+    )
+
+    parser.add_argument(
+        "--use-ctc",
+        type=str2bool,
+        default=False,
+        help="If True, use CTC head.",
+    )
+

def get_parser():
    parser = argparse.ArgumentParser(
@@ -385,6 +400,13 @@ def get_parser():
        "with this parameter before adding to the final loss.",
    )

+    parser.add_argument(
+        "--ctc-loss-scale",
+        type=float,
+        default=0.2,
+        help="Scale for CTC loss.",
+    )
+
    parser.add_argument(
        "--seed",
        type=int,
@@ -585,21 +607,33 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
    return joiner

-def get_transducer_model(params: AttributeDict) -> nn.Module:
+def get_model(params: AttributeDict) -> nn.Module:
+    assert (
+        params.use_transducer or params.use_ctc
+    ), (f"At least one of them should be True, "
+        f"but got params.use_transducer={params.use_transducer}, "
+        f"params.use_ctc={params.use_ctc}")
+
    encoder_embed = get_encoder_embed(params)
    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)

-    model = Transducer(
+    if params.use_transducer:
+        decoder = get_decoder_model(params)
+        joiner = get_joiner_model(params)
+    else:
+        decoder = None
+        joiner = None
+
+    model = AsrModel(
        encoder_embed=encoder_embed,
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
-        encoder_dim=int(max(params.encoder_dim.split(','))),
+        encoder_dim=max(_to_int_tuple(params.encoder_dim)),
        decoder_dim=params.decoder_dim,
-        joiner_dim=params.joiner_dim,
        vocab_size=params.vocab_size,
+        use_transducer=params.use_transducer,
+        use_ctc=params.use_ctc,
    )
    return model
@@ -728,7 +762,7 @@ def compute_loss(
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
-    Compute CTC loss given the model and its inputs.
+    Compute loss given the model and its inputs.

    Args:
      params:
@@ -766,7 +800,7 @@ def compute_loss(
    y = k2.RaggedTensor(y).to(device)

    with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, ctc_loss = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
@@ -775,6 +809,9 @@ def compute_loss(
            lm_scale=params.lm_scale,
        )

+        loss = 0.0
+
+        if params.use_transducer:
            s = params.simple_loss_scale
            # take down the scale on the simple loss from 1.0 at the start
            # to params.simple_loss scale by warm_step.
@@ -786,12 +823,14 @@ def compute_loss(
                1.0 if batch_idx_train >= warm_step
                else 0.1 + 0.9 * (batch_idx_train / warm_step)
            )
-            loss = (
+            loss += (
                simple_loss_scale * simple_loss +
                pruned_loss_scale * pruned_loss
            )

+        if params.use_ctc:
+            loss += params.ctc_loss_scale * ctc_loss
+
    assert loss.requires_grad == is_training

    info = MetricsTracker()
@@ -803,8 +842,11 @@ def compute_loss(
    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
+    if params.use_transducer:
        info["simple_loss"] = simple_loss.detach().cpu().item()
        info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if params.use_ctc:
+        info["ctc_loss"] = ctc_loss.detach().cpu().item()

    return loss, info
@@ -1081,10 +1123,13 @@ def run(rank, world_size, args):
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

+    if not params.use_transducer:
+        params.ctc_loss_scale = 1.0
+
    logging.info(params)

    logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")