modify the implementation of computing k

yifanyang 2023-02-07 19:33:59 +08:00
parent 06b583ee84
commit 6e377c6874
4 changed files with 68 additions and 104 deletions

View File

@@ -588,18 +588,10 @@ def greedy_search(
             1, context_size
         )
-        c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(
-            1, context_size + 1
-        )
-        k = torch.sum(
-            (
-                c[:, -context_size - 1 : -1]
-                == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
-            ).int(),
-            dim=1,
-            keepdim=True,
-        )
+        if hyp[-context_size - 1] == hyp[-1]:
+            k += 1
+        else:
+            k[0, 0] = 0
         decoder_out = model.decoder(decoder_input, k, need_pad=False)
         decoder_out = model.joiner.decoder_proj(decoder_out)
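
The replacement turns the per-step recomputation of k into a running counter: k is incremented while the newest token matches the token context_size positions back, and reset to zero once that run breaks. A self-contained sketch of the update rule on a made-up token stream (the values are illustrative, not from the recipe):

import torch

context_size = 2
k = torch.zeros(1, 1, dtype=torch.int64)  # same (1, 1) shape as above
hyp = [0, 0]  # greedy_search seeds the hypothesis with context_size blanks
for token in [5, 5, 5, 9, 5, 9, 5]:
    hyp.append(token)
    if hyp[-context_size - 1] == hyp[-1]:
        k += 1  # newest token repeats the one context_size steps back
    else:
        k[0, 0] = 0  # run broken: reset the counter
    print(token, int(k))  # k follows 0, 0, 1, 0, 1, 2, 3
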
@@ -712,21 +704,14 @@ def greedy_search_batch(
         # update decoder output
         decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
-        c = torch.tensor(
-            [h[-context_size - 1 :] for h in hyps[:batch_size]],
-            device=device,
-            dtype=torch.int64,
-        )
-        k = torch.sum(
-            (
-                c[:, :context_size]
-                == c[:, context_size : context_size + 1].expand_as(
-                    c[:, :context_size]
-                )
-            ).int(),
-            dim=1,
-            keepdim=True,
-        )
+        k = torch.where(
+            torch.tensor(
+                [h[-context_size - 1] for h in hyps[:batch_size]],
+                device=device,
+                dtype=torch.int64,
+            )
+            == decoder_input[:, -1],
+            k + 1,
+            torch.zeros(N, 1, device=device, dtype=torch.int64),
+        )
         decoder_input = torch.tensor(
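
greedy_search_batch applies the same reset-or-increment rule to every active hypothesis at once via torch.where. A standalone sketch of that pattern with toy values (N and all tensor contents are made up for illustration):

import torch

N = 3  # number of active hypotheses
k = torch.tensor([[2], [0], [5]], dtype=torch.int64)  # current run lengths
prev = torch.tensor([7, 3, 9])  # token context_size steps back, per stream
newest = torch.tensor([7, 4, 9])  # most recently emitted token, per stream
k = torch.where(
    (prev == newest).unsqueeze(1),  # (N, 1) mask: does the run continue?
    k + 1,  # yes: extend the run
    torch.zeros(N, 1, dtype=torch.int64),  # no: reset to zero
)
print(k)  # tensor([[3], [0], [6]])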

View File

@@ -89,15 +89,17 @@ class Decoder(nn.Module):
             A tensor of shape (N, U, decoder_dim).
           k:
             A tensor of shape (N, U).
-            Should be (N, S + 1) during training.
+            Should be (N, S) during training.
             Should be (N, 1) during inference.
+          is_training:
+            Whether it is training.
         Returns:
           Return a tensor of shape (N, U, decoder_dim).
         """
-        return embedding_out + torch.matmul(
-            (k / (1 + k)).unsqueeze(2),
-            self.repeat_param.unsqueeze(0),
-        )
+        if is_training:
+            k = F.pad(k, (1, 0), mode="constant", value=self.blank_id)
+
+        return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param

     def forward(
         self,
@@ -138,6 +140,7 @@ class Decoder(nn.Module):
             embedding_out = self._add_repeat_param(
                 embedding_out=embedding_out,
                 k=k,
+                is_training=need_pad,
             )
             embedding_out = F.relu(embedding_out)
         return embedding_out
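
In the decoder, k enters only through the weight k / (1 + k), which grows from 0 toward 1 as a repeat run lengthens, so repeat_param contributes more strongly for heavily repeated tokens. The rewrite replaces the matmul with a plain broadcast; a quick equivalence check with toy sizes (names and shapes are assumptions, with repeat_param taken to be a learned (decoder_dim,) vector):

import torch

N, U, D = 2, 3, 4  # batch, sequence length, decoder_dim (toy sizes)
k = torch.randint(0, 3, (N, U))
embedding_out = torch.randn(N, U, D)
repeat_param = torch.randn(D)

# Old form: (N, U, 1) @ (1, D) -> (N, U, D)
old = embedding_out + torch.matmul(
    (k / (1 + k)).unsqueeze(2), repeat_param.unsqueeze(0)
)
# New form: (N, U, 1) * (D,) broadcasts to the same (N, U, D)
new = embedding_out + (k / (1 + k)).unsqueeze(2) * repeat_param
assert torch.allclose(old, new)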

View File

@@ -24,7 +24,7 @@ import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import penalize_abs_values_gt
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos


 class Transducer(nn.Module):
@@ -72,13 +72,35 @@ class Transducer(nn.Module):
         )
         self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)

+    def _compute_k(
+        self,
+        y: torch.Tensor,
+        context_size: int = 2,
+    ) -> torch.Tensor:
+        """
+        Args:
+          y:
+            A 2-D tensor of shape (N, U).
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
+        Returns:
+          Return a tensor of shape (N, U).
+        """
+        y_shift = F.pad(
+            y, (context_size, 0), mode="constant", value=self.decoder.blank_id
+        )[:, :-context_size]
+        mask = y_shift != y
+        T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device)
+        cummax_out = (T_arange * mask).cummax(dim=-1)[0]
+        k = T_arange - cummax_out
+        return k
+
     def forward(
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        y: torch.Tensor,
-        y_lens: torch.Tensor,
-        k: torch.Tensor,
+        y: k2.RaggedTensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
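
_compute_k vectorizes the run-length computation: mask marks the positions where the sequence stops repeating with period context_size, cummax recovers the index of the most recent break, and subtracting that from arange gives the current run length. A standalone re-statement of the trick, checked against a naive per-position loop (toy sequence; blank_id=0 and context_size=2 are assumptions):

import torch
import torch.nn.functional as F


def compute_k_ref(y: torch.Tensor, context_size: int = 2, blank_id: int = 0) -> torch.Tensor:
    y_shift = F.pad(y, (context_size, 0), value=blank_id)[:, :-context_size]
    mask = y_shift != y
    t = torch.arange(y.size(1), device=y.device).expand_as(y)
    return t - (t * mask).cummax(dim=-1)[0]


y = torch.tensor([[0, 5, 5, 5, 9, 5, 9, 5]])  # starts with SOS/blank
k = compute_k_ref(y)
print(k)  # tensor([[0, 0, 0, 1, 0, 1, 2, 3]])

# Same semantics, spelled out: k[i] is the length of the run ending at i of
# positions with y[i] == y[i - context_size] (blank-padded on the left;
# position 0 is pinned to 0 by the arange trick).
ref = torch.zeros_like(y)
for i in range(1, y.size(1)):
    prev = y[0, i - 2] if i >= 2 else 0
    ref[0, i] = ref[0, i - 1] + 1 if y[0, i] == prev else 0
assert torch.equal(k, ref)
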
@@ -91,14 +113,8 @@ class Transducer(nn.Module):
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
             before padding.
           y:
-            A 2-D tensor with 2 axes [utt][label]. It contains labels of each
+            A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
-          y_lens:
-            A 1-D tensor of shape (N,). It contains the number of frames in `y`
-            before padding.
-          k:
-            A statistic given the context_size with respect to utt.
-            A 2-D tensor of shape (N, U).
           prune_range:
             The prune range for rnnt loss, it means how many symbols(context)
             we are considering for each frame to compute the loss.
@@ -119,24 +135,34 @@ class Transducer(nn.Module):
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
-        assert len(y.shape) == 2, len(y.shape)
+        assert y.num_axes == 2, y.num_axes

-        assert x.size(0) == x_lens.size(0) == y.size(0)
+        assert x.size(0) == x_lens.size(0) == y.dim0

         encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)

+        # Now for the decoder, i.e., the prediction network
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
         blank_id = self.decoder.blank_id
+        sos_y = add_sos(y, sos_id=blank_id)

         # sos_y_padded: [B, S + 1], start with SOS.
-        sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)
+        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+        # compute k
+        k = self._compute_k(sos_y_padded, context_size=self.decoder.context_size)

         # decoder_out: [B, S + 1, decoder_dim]
         decoder_out = self.decoder(sos_y_padded, k)

-        # Note: y_padded does not start with SOS
+        # Note: y does not start with SOS
         # y_padded : [B, S]
-        y_padded = y.to(torch.int64)
+        y_padded = y.pad(mode="constant", padding_value=0)
+        y_padded = y_padded.to(torch.int64)

         boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
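
With y stored as a k2.RaggedTensor, label lengths fall out of its row_splits vector: entry i+1 minus entry i is the number of labels in utterance i. What that bookkeeping computes, on toy data (plain torch stands in for the ragged tensor here):

import torch

# row_splits for three utterances holding 3, 5, and 2 labels respectively
row_splits = torch.tensor([0, 3, 8, 10])
y_lens = row_splits[1:] - row_splits[:-1]
print(y_lens)  # tensor([3, 5, 2])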

View File

@@ -39,7 +39,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless9/exp \
   --full-libri 1 \
-  --max-duration 550
+  --max-duration 750

"""
@@ -58,7 +58,6 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-import torch.nn.functional as F
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
@@ -70,7 +69,6 @@ from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils.rnn import pad_sequence
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -237,7 +235,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless9/exp",
+        default="pruned_transducer_stateless7/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -627,41 +625,6 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


-def compute_k(
-    y: torch.Tensor,
-    context_size: int = 2,
-    blank_id: int = 0,
-) -> torch.Tensor:
-    """
-    Args:
-      y:
-        A 2-D tensor of shape (N, U).
-      context_size:
-        Number of previous words to use to predict the next word.
-        1 means bigram; 2 means trigram. n means (n+1)-gram.
-    Returns:
-      Return a tensor of shape (N, U).
-    """
-    y = F.pad(y, (1, 0), mode="constant", value=blank_id)  # [B, S + 1], start with SOS.
-    k = torch.zeros_like(y)
-    for i in range(2, y.size(1) - 1):
-        k[:, i : i + 1] = torch.where(
-            y[:, i : i + 1] != 0,
-            torch.sum(
-                (
-                    y[:, i - context_size : i]
-                    == y[:, i : i + 1].expand_as(y[:, i - context_size : i])
-                ).int(),
-                dim=1,
-                keepdim=True,
-            ),
-            y[:, i : i + 1],
-        )
-    return k
-
-
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -712,26 +675,13 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
-    y_lens = torch.tensor(list(map(len, y))).to(device)
-    y = list(map(torch.tensor, y))
-    y = pad_sequence(y, batch_first=True)  # [B, S]
-    k = compute_k(
-        y,
-        params.context_size,
-        model.module.decoder.blank_id
-        if isinstance(model, DDP)
-        else model.decoder.blank_id,
-    ).to(device)
-    y = y.to(device)
+    y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
-            y_lens=y_lens,
-            k=k,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
@@ -1093,10 +1043,10 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)

-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds