Add repeat_param to embedding_out during training and decoding

yifanyang 2023-02-02 17:33:52 +08:00
parent 45b68e1df6
commit c3e01e141f
5 changed files with 154 additions and 46 deletions
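The change, as I read it from the diffs below: the stateless decoder gains a learnable vector repeat_param, and for every label position a count k of how many of the previous context_size labels equal the current label is computed (by compute_k during training, and incrementally during greedy search). repeat_param, scaled by k / (1 + k), is then added to the decoder embedding before the ReLU. A minimal sketch of the scaling, with made-up values that are not from the repository:

import torch

# k = 0, 1, 2, 3 repeats of the current label inside the context window
repeat_counts = torch.arange(4)
scale = repeat_counts / (1 + repeat_counts)
print(scale)  # tensor([0.0000, 0.5000, 0.6667, 0.7500]) -> saturates toward 1.0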

View File

@@ -542,7 +542,9 @@ def greedy_search(
         [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
     ).reshape(1, context_size)
-    decoder_out = model.decoder(decoder_input, need_pad=False)
+    k = torch.zeros(1, 1, device=device, dtype=torch.int64)
+    decoder_out = model.decoder(decoder_input, k, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -586,7 +588,20 @@ def greedy_search(
                1, context_size
            )
-            decoder_out = model.decoder(decoder_input, need_pad=False)
+            c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(
+                1, context_size + 1
+            )
+            k[:, 0] = torch.sum(
+                (
+                    c[:, -context_size - 1 : -1]
+                    == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
+                ).int(),
+                dim=1,
+                keepdim=True,
+            )
+            decoder_out = model.decoder(decoder_input, k, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)
             sym_per_utt += 1
@@ -594,7 +609,7 @@ def greedy_search(
        else:
            sym_per_frame = 0
            t += 1
-    hyp = hyp[context_size:]  # remove blanks
+    hyp = hyp[context_size :]  # remove blanks
    if not return_timestamps:
        return hyp
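A small worked example of the decoding-side update above, with a made-up hypothesis; c holds the last context_size + 1 tokens and k counts how many of the first context_size of them equal the newest token (written here with broadcasting instead of expand_as):

import torch

context_size = 2
hyp = [0, 0, 13, 13, 13]  # blanks for padding, then a token emitted three times

c = torch.tensor([hyp[-context_size - 1 :]])  # shape (1, context_size + 1) -> [[13, 13, 13]]
k = torch.zeros(1, 1, dtype=torch.int64)
k[:, 0] = (c[:, :-1] == c[:, -1:]).sum(dim=1)
print(k)  # tensor([[2]]): both context tokens match the newest token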

View File

@@ -20,36 +20,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
-    --epoch 28 \
+    --epoch 30 \
-    --avg 15 \
+    --avg 8 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method greedy_search
 (2) beam search (not recommended)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4
 (3) modified beam search
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4
 (4) fast beam search (one best)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 20.0 \
@@ -57,10 +57,10 @@ Usage:
     --max-states 64
 (5) fast beam search (nbest)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest \
     --beam 20.0 \
@@ -70,10 +70,10 @@ Usage:
     --nbest-scale 0.5
 (6) fast beam search (nbest oracle WER)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_oracle \
     --beam 20.0 \
@@ -83,10 +83,10 @@ Usage:
     --nbest-scale 0.5
 (7) fast beam search (with LG)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_LG \
     --beam 20.0 \
@@ -223,7 +223,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless7/exp",
+        default="pruned_transducer_stateless9/exp",
         help="The experiment dir",
     )

View File

@@ -73,14 +73,50 @@ class Decoder(nn.Module):
            bias=False,
        )

+        self.repeat_param = nn.Parameter(torch.randn(decoder_dim))
+
+    def _add_repeat_param(
+        self,
+        embedding_out: torch.Tensor,
+        k: torch.Tensor,
+        is_training: bool = True,
+    ) -> torch.Tensor:
+        """
+        Add the repeat parameter to the embedding_out tensor.
+        Args:
+          embedding_out:
+            A tensor of shape (N, U, decoder_dim).
+          k:
+            A tensor of shape (N, U).
+            Should be (N, S + 1) during training.
+            Should be (N, 1) during inference.
+        Returns:
+          Return a tensor of shape (N, U, decoder_dim).
+        """
+        return embedding_out + torch.matmul(
+            (k / (1 + k)).unsqueeze(2),
+            self.repeat_param.unsqueeze(0),
+        )
+
-    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
+    def forward(
+        self,
+        y: torch.Tensor,
+        k: torch.Tensor,
+        need_pad: bool = True,
+    ) -> torch.Tensor:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U).
+          k:
+            A 2-D tensor, statistic given the context_size with respect to utt.
+            Should be (N, S + 1) during training.
+            Should be (N, 1) during inference.
          need_pad:
-            True to left pad the input. Should be True during training.
-            False to not pad the input. Should be False during inference.
+            Whether to left pad the input.
+            Should be True during training.
+            Should be False during inference.
        Returns:
          Return a tensor of shape (N, U, decoder_dim).
        """
@@ -90,7 +126,7 @@ class Decoder(nn.Module):
        embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
        if self.context_size > 1:
            embedding_out = embedding_out.permute(0, 2, 1)
            if need_pad is True:
                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
            else:
                # During inference time, there is no need to do extra padding
@@ -98,5 +134,10 @@ class Decoder(nn.Module):
                assert embedding_out.size(-1) == self.context_size
            embedding_out = self.conv(embedding_out)
            embedding_out = embedding_out.permute(0, 2, 1)
+        embedding_out = self._add_repeat_param(
+            embedding_out=embedding_out,
+            k=k,
+        )
        embedding_out = F.relu(embedding_out)
        return embedding_out
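The torch.matmul in _add_repeat_param is a rank-1 outer product: each position's scale k / (1 + k) multiplies the single repeat_param vector. A short check of that reading, with hypothetical shapes that are not taken from the repository:

import torch

N, U, decoder_dim = 2, 5, 8
embedding_out = torch.randn(N, U, decoder_dim)
k = torch.randint(0, 3, (N, U))       # per-position repeat counts
repeat_param = torch.randn(decoder_dim)

# (N, U, 1) @ (1, decoder_dim) -> (N, U, decoder_dim)
via_matmul = embedding_out + torch.matmul(
    (k / (1 + k)).unsqueeze(2), repeat_param.unsqueeze(0)
)
# The same term written as a broadcast multiply
via_broadcast = embedding_out + (k / (1 + k)).unsqueeze(2) * repeat_param
assert torch.allclose(via_matmul, via_broadcast)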

View File

@@ -20,10 +20,11 @@ import random
 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import penalize_abs_values_gt
-from icefall.utils import add_sos
+from icefall.utils import add_sos, make_pad_mask

 class Transducer(nn.Module):
@@ -75,7 +76,9 @@ class Transducer(nn.Module):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        y: k2.RaggedTensor,
+        y: torch.Tensor,
+        y_lens: torch.Tensor,
+        k: torch.Tensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
@@ -88,8 +91,14 @@ class Transducer(nn.Module):
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
-            A ragged tensor with 2 axes [utt][label]. It contains labels of each
+            A 2-D tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
+          y_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `y`
+            before padding.
+          k:
+            A statistic given the context_size with respect to utt.
+            A 2-D tensor of shape (N, U).
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
@@ -110,31 +119,24 @@ class Transducer(nn.Module):
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
-        assert y.num_axes == 2, y.num_axes
+        assert len(y.shape) == 2, len(y.shape)
-        assert x.size(0) == x_lens.size(0) == y.dim0
+        assert x.size(0) == x_lens.size(0) == y.size(0)

        encoder_out, x_lens = self.encoder(x, x_lens)
        assert torch.all(x_lens > 0)

-        # Now for the decoder, i.e., the prediction network
-        row_splits = y.shape.row_splits(1)
-        y_lens = row_splits[1:] - row_splits[:-1]
        blank_id = self.decoder.blank_id
-        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
-        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
-        decoder_out = self.decoder(sos_y_padded)
+        decoder_out = self.decoder(sos_y_padded, k)

-        # Note: y does not start with SOS
+        # Note: y_padded does not start with SOS
        # y_padded : [B, S]
-        y_padded = y.pad(mode="constant", padding_value=0)
-        y_padded = y_padded.to(torch.int64)
+        y_padded = y.to(torch.int64)

        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
        boundary[:, 2] = y_lens
        boundary[:, 3] = x_lens
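With y now a padded 2-D tensor rather than a k2.RaggedTensor, prepending SOS reduces to an F.pad with the blank id, and the label lengths are passed in explicitly as y_lens. A toy illustration with made-up values:

import torch
import torch.nn.functional as F

blank_id = 0
y = torch.tensor([[5, 7, 7], [9, 3, 0]])   # (N, S), second row padded with 0
y_lens = torch.tensor([3, 2])              # true label lengths, now given to forward()

sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)  # (N, S + 1), starts with blank/SOS
print(sos_y_padded)
# tensor([[0, 5, 7, 7],
#         [0, 9, 3, 0]])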

View File

@@ -22,22 +22,22 @@ Usage:
 export CUDA_VISIBLE_DEVICES="0,1,2,3"
-./pruned_transducer_stateless7/train.py \
+./pruned_transducer_stateless9/train.py \
     --world-size 4 \
     --num-epochs 30 \
     --start-epoch 1 \
-    --exp-dir pruned_transducer_stateless7/exp \
+    --exp-dir pruned_transducer_stateless9/exp \
     --full-libri 1 \
     --max-duration 300
 # For mix precision training:
-./pruned_transducer_stateless7/train.py \
+./pruned_transducer_stateless9/train.py \
     --world-size 4 \
     --num-epochs 30 \
     --start-epoch 1 \
     --use-fp16 1 \
-    --exp-dir pruned_transducer_stateless7/exp \
+    --exp-dir pruned_transducer_stateless9/exp \
     --full-libri 1 \
     --max-duration 550
@@ -58,6 +58,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+import torch.nn.functional as F
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
@@ -69,6 +70,7 @@ from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils.rnn import pad_sequence
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -235,7 +237,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless7/exp",
+        default="pruned_transducer_stateless9/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -625,6 +627,41 @@ def save_checkpoint(
     copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_k(
+    y: torch.Tensor,
+    context_size: int = 2,
+    blank_id: int = 0,
+) -> torch.Tensor:
+    """
+    Args:
+      y:
+        A 2-D tensor of shape (N, U).
+      context_size:
+        Number of previous words to use to predict the next word.
+        1 means bigram; 2 means trigram. n means (n+1)-gram.
+    Returns:
+      Return a tensor of shape (N, U).
+    """
+    y = F.pad(y, (1, 0), mode="constant", value=blank_id)  # [B, S + 1], start with SOS.
+    k = torch.zeros_like(y)
+    for i in range(2, y.size(1) - 1):
+        k[:, i : i + 1] = torch.where(
+            y[:, i : i + 1] != 0,
+            torch.sum(
+                (
+                    y[:, i - context_size : i]
+                    == y[:, i : i + 1].expand_as(y[:, i - context_size : i])
+                ).int(),
+                dim=1,
+                keepdim=True,
+            ),
+            y[:, i : i + 1],
+        )
+    return k
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -675,13 +712,26 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    y_lens = torch.tensor(list(map(len, y))).to(device)
+    y = list(map(torch.tensor, y))
+    y = pad_sequence(y, batch_first=True)  # [B, S]
+    k = compute_k(
+        y,
+        params.context_size,
+        model.module.decoder.blank_id
+        if isinstance(model, DDP)
+        else model.decoder.blank_id,
+    ).to(device)
+    y = y.to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
+            y_lens=y_lens,
+            k=k,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
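compute_k builds the repeat statistic on the padded label tensor before compute_loss hands it to the model together with y and y_lens. A quick hypothetical run; the import path is illustrative, and by my reading of the loop the first emitted label and the last column stay at zero because the loop runs over range(2, y.size(1) - 1):

import torch
from train import compute_k  # hypothetical import; compute_k is defined in the diff above

y = torch.tensor([[13, 13, 13, 7]])   # (N, S) padded label ids, 0 is the blank id
k = compute_k(y, context_size=2, blank_id=0)
print(k)
# expected (by my reading): tensor([[0, 0, 1, 2, 0]])
# k has shape (N, S + 1), aligned with the SOS-padded decoder input; position 3
# sees two previous 13s equal to itself, position 2 sees one.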