Add repeat_param to embedding_out during training and decoding

yifanyang 2023-02-02 17:33:52 +08:00
parent 45b68e1df6
commit c3e01e141f
5 changed files with 154 additions and 46 deletions

View File

@@ -542,7 +542,9 @@ def greedy_search(
[-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
k = torch.zeros(1, 1, device=device, dtype=torch.int64)
decoder_out = model.decoder(decoder_input, k, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -586,7 +588,20 @@ def greedy_search(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(
1, context_size + 1
)
k[:, 0] = torch.sum(
(
c[:, -context_size - 1 : -1]
== c[:, -1].expand_as(c[:, -context_size - 1 : -1])
).int(),
dim=1,
keepdim=True,
)
decoder_out = model.decoder(decoder_input, k, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
sym_per_utt += 1
@@ -594,7 +609,7 @@ def greedy_search(
else:
sym_per_frame = 0
t += 1
hyp = hyp[context_size:] # remove blanks
hyp = hyp[context_size :] # remove blanks
if not return_timestamps:
return hyp

View File

@@ -20,36 +20,36 @@
"""
Usage:
(1) greedy search
./pruned_transducer_stateless7/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
./pruned_transducer_stateless9/decode.py \
--epoch 30 \
--avg 8 \
--exp-dir ./pruned_transducer_stateless9/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
./pruned_transducer_stateless7/decode.py \
./pruned_transducer_stateless9/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--exp-dir ./pruned_transducer_stateless9/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless7/decode.py \
./pruned_transducer_stateless9/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--exp-dir ./pruned_transducer_stateless9/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
./pruned_transducer_stateless7/decode.py \
./pruned_transducer_stateless9/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--exp-dir ./pruned_transducer_stateless9/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
@@ -57,10 +57,10 @@ Usage:
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless7/decode.py \
./pruned_transducer_stateless9/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--exp-dir ./pruned_transducer_stateless9/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
@@ -70,10 +70,10 @@ Usage:
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless7/decode.py \
./pruned_transducer_stateless9/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--exp-dir ./pruned_transducer_stateless9/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
@@ -83,10 +83,10 @@ Usage:
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless7/decode.py \
./pruned_transducer_stateless9/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless7/exp \
--exp-dir ./pruned_transducer_stateless9/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
@@ -223,7 +223,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
default="pruned_transducer_stateless9/exp",
help="The experiment dir",
)

View File

@@ -73,14 +73,50 @@ class Decoder(nn.Module):
bias=False,
)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
self.repeat_param = nn.Parameter(torch.randn(decoder_dim))
def _add_repeat_param(
self,
embedding_out: torch.Tensor,
k: torch.Tensor,
is_training: bool = True,
) -> torch.Tensor:
"""
Add the repeat parameter to the embedding_out tensor.
Args:
embedding_out:
A tensor of shape (N, U, decoder_dim).
k:
A tensor of shape (N, U).
Should be (N, S + 1) during training.
Should be (N, 1) during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
return embedding_out + torch.matmul(
(k / (1 + k)).unsqueeze(2),
self.repeat_param.unsqueeze(0),
)
def forward(
self,
y: torch.Tensor,
k: torch.Tensor,
need_pad: bool = True,
) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
k:
A 2-D tensor of repetition counts: for each position, the number of the previous context_size labels that equal the current label.
Should be (N, S + 1) during training.
Should be (N, 1) during inference.
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Whether to left pad the input.
Should be True during training.
Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
@@ -90,7 +126,7 @@ class Decoder(nn.Module):
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
@@ -98,5 +134,10 @@ class Decoder(nn.Module):
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = self._add_repeat_param(
embedding_out=embedding_out,
k=k,
)
embedding_out = F.relu(embedding_out)
return embedding_out
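The added term scales one learned vector by k / (1 + k): it contributes nothing where a label repeats none of its previous context_size labels (k = 0) and saturates towards the full vector as the repetition count grows. A standalone sketch of the shape arithmetic in _add_repeat_param, with toy sizes:

import torch

N, U, decoder_dim = 2, 5, 8
repeat_param = torch.randn(decoder_dim)         # learned vector of size decoder_dim
embedding_out = torch.randn(N, U, decoder_dim)  # decoder embeddings before the ReLU
k = torch.tensor([[0, 1, 2, 0, 1],
                  [2, 2, 0, 1, 0]])             # repetition counts per position

scale = (k / (1 + k)).unsqueeze(2)              # (N, U, 1), values in [0, 1)
# (N, U, 1) @ (1, decoder_dim) -> (N, U, decoder_dim); positions with k == 0 stay unchanged.
out = embedding_out + torch.matmul(scale, repeat_param.unsqueeze(0))
assert out.shape == (N, U, decoder_dim)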

View File

@@ -20,10 +20,11 @@ import random
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt
from icefall.utils import add_sos
from icefall.utils import add_sos, make_pad_mask
class Transducer(nn.Module):
@@ -75,7 +76,9 @@ class Transducer(nn.Module):
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
y: torch.Tensor,
y_lens: torch.Tensor,
k: torch.Tensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
@@ -88,8 +91,14 @@ class Transducer(nn.Module):
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
A 2-D tensor of shape (N, S), padded with zeros. It contains the labels of each
utterance.
y_lens:
A 1-D tensor of shape (N,). It contains the number of labels in `y`
before padding.
k:
A 2-D tensor of shape (N, S + 1). For each position, it gives the number
of times the current label occurs among the previous context_size labels.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
@@ -110,31 +119,24 @@ class Transducer(nn.Module):
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert len(y.shape) == 2, len(y.shape)
assert x.size(0) == x_lens.size(0) == y.dim0
assert x.size(0) == x_lens.size(0) == y.size(0)
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
decoder_out = self.decoder(sos_y_padded, k)
# Note: y does not start with SOS
# Note: y_padded does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
y_padded = y.to(torch.int64)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
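With y passed in as a padded 2-D tensor rather than a k2.RaggedTensor, the SOS prefix and the rnnt-loss boundary are built with plain tensor ops. A standalone sketch with toy values (blank_id = 0, zero-padded labels):

import torch
import torch.nn.functional as F

blank_id = 0
y = torch.tensor([[3, 5, 0],        # zero-padded labels, shape (N, S)
                  [7, 2, 9]])
y_lens = torch.tensor([2, 3])       # label lengths before padding
x_lens = torch.tensor([10, 12])     # encoder output lengths

# sos_y_padded: (N, S + 1), every row starts with the blank/SOS symbol.
sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)

# Columns 2 and 3 carry the label and frame lengths used by the pruned rnnt loss;
# columns 0 and 1 stay zero.
boundary = torch.zeros((y.size(0), 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens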

View File

@@ -22,22 +22,22 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7/train.py \
./pruned_transducer_stateless9/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless7/exp \
--exp-dir pruned_transducer_stateless9/exp \
--full-libri 1 \
--max-duration 300
# For mixed precision training:
./pruned_transducer_stateless7/train.py \
./pruned_transducer_stateless9/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7/exp \
--exp-dir pruned_transducer_stateless9/exp \
--full-libri 1 \
--max-duration 550
@@ -58,6 +58,7 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from joiner import Joiner
@@ -69,6 +70,7 @@ from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -235,7 +237,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
default="pruned_transducer_stateless9/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -625,6 +627,41 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
def compute_k(
y: torch.Tensor,
context_size: int = 2,
blank_id: int = 0,
) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
blank_id:
The ID of the blank symbol; it is also used as the SOS value when
left-padding `y`.
Returns:
Return a tensor of shape (N, U + 1); column 0 corresponds to the prepended SOS.
"""
y = F.pad(y, (1, 0), mode="constant", value=blank_id) # [B, S + 1], start with SOS.
k = torch.zeros_like(y)
for i in range(2, y.size(1) - 1):
k[:, i : i + 1] = torch.where(
y[:, i : i + 1] != 0,
torch.sum(
(
y[:, i - context_size : i]
== y[:, i : i + 1].expand_as(y[:, i - context_size : i])
).int(),
dim=1,
keepdim=True,
),
y[:, i : i + 1],
)
return k
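A usage sketch for compute_k, assuming the function defined above is in scope (context_size = 2, blank_id = 0). The output has one extra column for the prepended SOS:

import torch

y = torch.tensor([[5, 5, 5, 7, 5]])  # toy label IDs for one utterance
k = compute_k(y, context_size=2, blank_id=0)
print(k)  # tensor([[0, 0, 1, 2, 0, 0]])
# Column 0 is the prepended SOS and column 1 is skipped (the loop starts at i = 2).
# Column 2: one of the two previous labels equals 5; column 3: both do.
# Column 4: neither of the two previous labels equals 7.
# The last column stays 0 because the loop stops before the final position.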
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
@@ -675,13 +712,26 @@ def compute_loss(
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
y_lens = torch.tensor(list(map(len, y))).to(device)
y = list(map(torch.tensor, y))
y = pad_sequence(y, batch_first=True) # [B, S]
k = compute_k(
y,
params.context_size,
model.module.decoder.blank_id
if isinstance(model, DDP)
else model.decoder.blank_id,
).to(device)
y = y.to(device)
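For reference, a standalone sketch of the padding step above with toy token IDs (pad_sequence zero-pads by default, which matches the blank/padding convention used by compute_k):

import torch
from torch.nn.utils.rnn import pad_sequence

tokens = [[12, 7, 7], [5, 9]]                    # toy output of sp.encode(texts, out_type=int)
y_lens = torch.tensor([len(t) for t in tokens])  # tensor([3, 2]) -> lengths before padding
y = pad_sequence([torch.tensor(t) for t in tokens], batch_first=True)
print(y)  # tensor([[12,  7,  7],
          #         [ 5,  9,  0]])  zero-padded (N, S) labels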
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
y_lens=y_lens,
k=k,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,