Modify the implementation of computing k

yifanyang 2023-02-07 19:33:59 +08:00
parent 06b583ee84
commit 6e377c6874
4 changed files with 68 additions and 104 deletions

pruned_transducer_stateless9/beam_search.py

@@ -588,18 +588,10 @@ def greedy_search(
1, context_size
)
c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(
1, context_size + 1
)
k = torch.sum(
(
c[:, -context_size - 1 : -1]
== c[:, -1].expand_as(c[:, -context_size - 1 : -1])
).int(),
dim=1,
keepdim=True,
)
if hyp[-context_size - 1] == hyp[-1]:
k += 1
else:
k[0, 0] = 0
decoder_out = model.decoder(decoder_input, k, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
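The removed block recomputed k from scratch at every step by counting, inside the last context_size symbols, how many match the newest one. The new code keeps a running counter instead: k grows by one whenever the newest token equals the token context_size positions earlier, and resets to zero otherwise. A minimal sketch of that recurrence (the helper name update_k is mine, not part of the recipe):

# Running counter used at inference time: k is the number of consecutive steps
# for which the newest symbol has equalled the symbol context_size positions earlier.
def update_k(hyp: list, k: int, context_size: int = 2) -> int:
    if len(hyp) > context_size and hyp[-1] == hyp[-context_size - 1]:
        return k + 1  # the repetition run continues
    return 0          # the run is broken; reset the counter

# With context_size = 2 and hyp = [a, b, a, b, a], k grows by one per step
# as long as the alternating a/b pattern holds.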
@@ -712,21 +704,14 @@ def greedy_search_batch(
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
c = torch.tensor(
[h[-context_size - 1 :] for h in hyps[:batch_size]],
device=device,
dtype=torch.int64,
)
k = torch.sum(
(
c[:, :context_size]
== c[:, context_size : context_size + 1].expand_as(
c[:, :context_size]
)
).int(),
dim=1,
keepdim=True,
k = torch.where(
torch.tensor(
[h[-context_size - 1] for h in hyps[:batch_size]],
device=device,
dtype=torch.int64,
) == decoder_input[:, -1],
k + 1,
torch.zeros(N, 1, device=device, dtype=torch.int64),
)
decoder_input = torch.tensor(
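greedy_search_batch applies the same recurrence to all N active streams at once with torch.where: where the newest token matches the token context_size steps back, the counter is incremented; elsewhere it is reset to zero. A hedged sketch of just that update (the function name and the (N, 1) shapes are my assumptions):

import torch

def update_k_batch(k: torch.Tensor, last: torch.Tensor, old: torch.Tensor) -> torch.Tensor:
    # last holds hyp[-1] per stream, old holds hyp[-context_size - 1] per stream,
    # and k is the current counter; all three are of shape (N, 1).
    return torch.where(old == last, k + 1, torch.zeros_like(k))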

pruned_transducer_stateless9/decoder.py

@@ -89,15 +89,17 @@ class Decoder(nn.Module):
A tensor of shape (N, U, decoder_dim).
k:
A tensor of shape (N, U).
Should be (N, S + 1) during training.
Should be (N, S) during training.
Should be (N, 1) during inference.
is_training:
Whether it is training.
Returns:
Return a tensor of shape (N, U, decoder_dim).
"""
return embedding_out + torch.matmul(
(k / (1 + k)).unsqueeze(2),
self.repeat_param.unsqueeze(0),
)
if is_training:
k = F.pad(k, (1, 0), mode="constant", value=self.blank_id)
return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param
def forward(
self,
@@ -138,6 +140,7 @@ class Decoder(nn.Module):
embedding_out = self._add_repeat_param(
embedding_out=embedding_out,
k=k,
is_training=need_pad,
)
embedding_out = F.relu(embedding_out)
return embedding_out
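Inside the decoder, k only enters through _add_repeat_param: the count is squashed to k / (1 + k), which lies in [0, 1) and saturates as the repetition gets longer, and that value scales a learned repeat_param vector added to the embedding output. During training, k (shape (N, S) per the updated docstring) is left-padded with one entry so it lines up with the SOS-prefixed decoder input of length S + 1. A self-contained sketch of the arithmetic (the standalone function is mine; repeat_param is an nn.Parameter in the real module):

import torch
import torch.nn.functional as F

def add_repeat_param(
    embedding_out: torch.Tensor,  # (N, S + 1, decoder_dim) in training, (N, 1, decoder_dim) at inference
    k: torch.Tensor,              # (N, S) in training, (N, 1) at inference
    repeat_param: torch.Tensor,   # (decoder_dim,)
    is_training: bool,
) -> torch.Tensor:
    if is_training:
        # Prepend one entry (blank_id in the recipe) so k matches the
        # SOS-prefixed decoder input of length S + 1.
        k = F.pad(k, (1, 0), value=0)
    # k / (1 + k) maps 0, 1, 2, ... to 0.0, 0.5, 0.667, ... in [0, 1).
    return embedding_out + (k / (1 + k)).unsqueeze(2) * repeat_param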

pruned_transducer_stateless9/model.py

@@ -24,7 +24,7 @@ import torch.nn.functional as F
from encoder_interface import EncoderInterface
from scaling import penalize_abs_values_gt
from icefall.utils import add_sos, make_pad_mask
from icefall.utils import add_sos
class Transducer(nn.Module):
@@ -72,13 +72,35 @@ class Transducer(nn.Module):
)
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
def _compute_k(
self,
y: torch.Tensor,
context_size: int = 2,
) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
Returns:
Return a tensor of shape (N, U).
"""
y_shift = F.pad(y, (context_size, 0), mode="constant", value=self.decoder.blank_id)[:, :-context_size]
mask = y_shift != y
T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device)
cummax_out = (T_arange * mask).cummax(dim=-1)[0]
k = T_arange - cummax_out
return k
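_compute_k vectorizes the statistic with a cummax trick: build a mask of positions where the repetition breaks (y[t] differs from y[t - context_size], with blank padding in front), take the running maximum of index * mask to find, for every position, the index of the most recent break, and subtract it from the position index to get the length of the current run. A standalone sketch with a worked example (the function name and example values are mine; blank_id is assumed to be 0):

import torch
import torch.nn.functional as F

def compute_k(y: torch.Tensor, context_size: int = 2, blank_id: int = 0) -> torch.Tensor:
    # Shift y right by context_size, filling the front with blank_id.
    y_shift = F.pad(y, (context_size, 0), value=blank_id)[:, :-context_size]
    mask = y_shift != y                            # True where the repetition run breaks
    idx = torch.arange(y.size(1), device=y.device).expand_as(y)
    latest_break = (idx * mask).cummax(dim=-1)[0]  # index of the most recent break
    return idx - latest_break                      # length of the current run

y = torch.tensor([[5, 7, 5, 7, 5, 9]])
print(compute_k(y))  # tensor([[0, 0, 1, 2, 3, 0]]): the 5/7 pattern repeats with period 2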
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: torch.Tensor,
y_lens: torch.Tensor,
k: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
@@ -91,14 +113,8 @@ class Transducer(nn.Module):
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A 2-D tensor with 2 axes [utt][label]. It contains labels of each
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
y_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `y`
before padding.
k:
A repetition statistic for each utterance, computed from the labels with the given context_size.
A 2-D tensor of shape (N, U).
prune_range:
The prune range for rnnt loss; it means how many symbols (context)
we are considering for each frame to compute the loss.
@@ -119,24 +135,34 @@ class Transducer(nn.Module):
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert len(y.shape) == 2, len(y.shape)
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.size(0)
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# compute k
k = self._compute_k(sos_y_padded, context_size=self.decoder.context_size)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded, k)
# Note: y_padded does not start with SOS
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.to(torch.int64)
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
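With y now a k2.RaggedTensor, everything forward() needs is derived internally: the lengths come from the row_splits of axis 1, the SOS-prefixed decoder input comes from add_sos followed by pad, and k is computed from that padded input. A small sketch of this bookkeeping, using only calls that appear in the diff (the example values are mine; blank_id is assumed to be 0):

import k2
import torch
from icefall.utils import add_sos

blank_id = 0
y = k2.RaggedTensor([[2, 5, 5], [7]])            # two utterances with 3 and 1 labels

row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]        # tensor([3, 1])

sos_y = add_sos(y, sos_id=blank_id)              # prepends blank: [[0, 2, 5, 5], [0, 7]]
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)   # (N, S + 1)
y_padded = y.pad(mode="constant", padding_value=0).to(torch.int64)  # (N, S), no SOS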

pruned_transducer_stateless9/train.py

@@ -39,7 +39,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--use-fp16 1 \
--exp-dir pruned_transducer_stateless9/exp \
--full-libri 1 \
--max-duration 550
--max-duration 750
"""
@@ -58,7 +58,6 @@ import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from joiner import Joiner
@@ -70,7 +69,6 @@ from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
@@ -237,7 +235,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless9/exp",
default="pruned_transducer_stateless7/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -627,41 +625,6 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)
def compute_k(
y: torch.Tensor,
context_size: int = 2,
blank_id: int = 0,
) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
Returns:
Return a tensor of shape (N, U).
"""
y = F.pad(y, (1, 0), mode="constant", value=blank_id) # [B, S + 1], start with SOS.
k = torch.zeros_like(y)
for i in range(2, y.size(1) - 1):
k[:, i : i + 1] = torch.where(
y[:, i : i + 1] != 0,
torch.sum(
(
y[:, i - context_size : i]
== y[:, i : i + 1].expand_as(y[:, i - context_size : i])
).int(),
dim=1,
keepdim=True,
),
y[:, i : i + 1],
)
return k
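Note that the deleted helper computed a different statistic from the new Transducer._compute_k: for each position it counted how many of the previous context_size labels equal the current label (so k was bounded by context_size and required a Python loop over the sequence), whereas the new formulation measures how long the sequence has been repeating with period context_size and can grow past context_size. A hedged restatement of the old definition for comparison (the boundary handling of the original is simplified here):

import torch
import torch.nn.functional as F

def old_k(y: torch.Tensor, context_size: int = 2, blank_id: int = 0) -> torch.Tensor:
    y_sos = F.pad(y, (1, 0), value=blank_id)        # [B, S + 1], starts with SOS
    k = torch.zeros_like(y_sos)
    for offset in range(1, context_size + 1):
        prev = F.pad(y_sos, (offset, 0), value=blank_id)[:, :-offset]
        k += (prev == y_sos).long()                 # match against each previous position
    return k * (y_sos != blank_id).long()           # zero out blank/SOS positions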
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
@@ -712,26 +675,13 @@ def compute_loss(
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y_lens = torch.tensor(list(map(len, y))).to(device)
y = list(map(torch.tensor, y))
y = pad_sequence(y, batch_first=True) # [B, S]
k = compute_k(
y,
params.context_size,
model.module.decoder.blank_id
if isinstance(model, DDP)
else model.decoder.blank_id,
).to(device)
y = y.to(device)
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
y_lens=y_lens,
k=k,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
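On the training-loop side, compute_loss no longer pads the labels or computes k itself; it simply wraps the sentencepiece output in a ragged tensor and lets the model handle padding and k. A tiny illustration of that conversion (the token ids are made up; in the recipe they come from sp.encode(texts, out_type=int)):

import k2
import torch

device = torch.device("cpu")              # "cuda" during real training
y_ids = [[12, 7, 903], [88]]              # per-utterance token ids from sentencepiece
y = k2.RaggedTensor(y_ids).to(device)     # no padding here; the model pads internally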
@@ -1093,10 +1043,10 @@ def run(rank, world_size, args):
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
if params.full_libri:
train_cuts += librispeech.train_clean_360_cuts()
train_cuts += librispeech.train_other_500_cuts()
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds