Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 10:02:22 +00:00
modify the implementation of computing k
parent 06b583ee84
commit 6e377c6874
@@ -588,18 +588,10 @@ def greedy_search(
                 1, context_size
             )
-            c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(
-                1, context_size + 1
-            )
-            k = torch.sum(
-                (
-                    c[:, -context_size - 1 : -1]
-                    == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
-                ).int(),
-                dim=1,
-                keepdim=True,
-            )
+            if hyp[-context_size - 1] == hyp[-1]:
+                k += 1
+            else:
+                k[0, 0] = 0
 
             decoder_out = model.decoder(decoder_input, k, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)
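The new branch replaces the per-step windowed sum with an incremental counter: `k` is carried across decoding steps, incremented while the newest token still matches the token `context_size + 1` positions back, and reset as soon as the pattern breaks. A toy replay of that update for a single hypothesis (the initialization of `k` happens outside this hunk, so it is assumed here; `hyp` is illustrative):

```python
import torch

context_size = 2
hyp = [0, 0, 3, 5, 5, 5]  # grows one token per decoding step in the real loop
k = torch.zeros(1, 1, dtype=torch.int64)  # assumed initialized before the loop

for t in range(context_size + 1, len(hyp) + 1):
    cur = hyp[:t]
    if cur[-context_size - 1] == cur[-1]:
        k += 1       # the newest token continues the repeat pattern
    else:
        k[0, 0] = 0  # pattern broken: reset the counter
    print(cur, int(k))  # k then feeds model.decoder(decoder_input, k, ...)
```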
@@ -712,21 +704,14 @@ def greedy_search_batch(
         # update decoder output
         decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
 
-        c = torch.tensor(
-            [h[-context_size - 1 :] for h in hyps[:batch_size]],
-            device=device,
-            dtype=torch.int64,
-        )
-
-        k = torch.sum(
-            (
-                c[:, :context_size]
-                == c[:, context_size : context_size + 1].expand_as(
-                    c[:, :context_size]
-                )
-            ).int(),
-            dim=1,
-            keepdim=True,
-        )
+        k = torch.where(
+            torch.tensor(
+                [h[-context_size - 1] for h in hyps[:batch_size]],
+                device=device,
+                dtype=torch.int64,
+            ) == decoder_input[:, -1],
+            k + 1,
+            torch.zeros(N, 1, device=device, dtype=torch.int64),
+        )
 
         decoder_input = torch.tensor(
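Batched, the same reset-or-increment becomes a single `torch.where` over all live hypotheses, again assuming `k` persists across decoding steps. A self-contained sketch (`N` and the toy `hyps` are illustrative; the diff compares against `decoder_input[:, -1]`, while this sketch builds the comparison directly from the hypothesis lists):

```python
import torch

context_size = 2
hyps = [[0, 3, 3, 3], [0, 0, 5, 7]]  # two partial hypotheses
N = len(hyps)
k = torch.zeros(N, 1, dtype=torch.int64)  # carried over from earlier steps

last = torch.tensor([[h[-1]] for h in hyps])                 # newest token
back = torch.tensor([[h[-context_size - 1]] for h in hyps])  # token context_size + 1 back
k = torch.where(back == last, k + 1, torch.zeros_like(k))
print(k.squeeze(1).tolist())  # [1, 0]: only the first hypothesis repeats
```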
@@ -89,15 +89,17 @@ class Decoder(nn.Module):
             A tensor of shape (N, U, decoder_dim).
           k:
             A tensor of shape (N, U).
-            Should be (N, S + 1) during training.
+            Should be (N, S) during training.
             Should be (N, 1) during inference.
+          is_training:
+            Whether it is training.
         Returns:
           Return a tensor of shape (N, U, decoder_dim).
         """
-        return embedding_out + torch.matmul(
-            (k / (1 + k)).unsqueeze(2),
-            self.repeat_param.unsqueeze(0),
-        )
+        if is_training:
+            k = F.pad(k, (1, 0), mode="constant", value=self.blank_id)
+
+        return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param
 
     def forward(
         self,
@@ -138,6 +140,7 @@ class Decoder(nn.Module):
         embedding_out = self._add_repeat_param(
             embedding_out=embedding_out,
             k=k,
+            is_training=need_pad,
         )
         embedding_out = F.relu(embedding_out)
         return embedding_out
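The rewritten return swaps the explicit `torch.matmul` for plain broadcasting: `k / (1 + k)` is a per-position gate that grows 0, 1/2, 2/3, ... toward 1 as a token keeps repeating, and scaling the learnable `repeat_param` by it injects a "this token is repeating" signal into the embedding. A shape sketch with made-up dimensions (`repeat_param` is an `nn.Parameter` in the real `Decoder`; here it is random):

```python
import torch

N, U, decoder_dim = 2, 4, 8                # made-up sizes for the demo
embedding_out = torch.zeros(N, U, decoder_dim)
repeat_param = torch.randn(decoder_dim)    # learnable in the real Decoder
k = torch.tensor([[0, 1, 2, 3],
                  [0, 0, 1, 0]])           # repeat counts per position

gate = (k / (1 + k)).unsqueeze(2)          # (N, U, 1): 0, 1/2, 2/3, 3/4, ...
out = embedding_out + gate * repeat_param  # broadcasts to (N, U, decoder_dim)
print(out.shape)                           # torch.Size([2, 4, 8])
```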
@@ -24,7 +24,7 @@ import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import penalize_abs_values_gt
 
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos
 
 
 class Transducer(nn.Module):
@@ -72,13 +72,35 @@ class Transducer(nn.Module):
         )
         self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
 
+    def _compute_k(
+        self,
+        y: torch.Tensor,
+        context_size: int = 2,
+    ) -> torch.Tensor:
+        """
+        Args:
+          y:
+            A 2-D tensor of shape (N, U).
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
+        Returns:
+          Return a tensor of shape (N, U).
+        """
+        y_shift = F.pad(y, (context_size, 0), mode="constant", value=self.decoder.blank_id)[:, :-context_size]
+        mask = y_shift != y
+
+        T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device)
+        cummax_out = (T_arange * mask).cummax(dim=-1)[0]
+        k = T_arange - cummax_out
+
+        return k
+
     def forward(
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        y: torch.Tensor,
-        y_lens: torch.Tensor,
-        k: torch.Tensor,
+        y: k2.RaggedTensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
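The new `_compute_k` gets these run lengths without a Python loop. Comparing `y` with a copy of itself shifted right by `context_size` marks every position where the lag-`context_size` match breaks; `cummax` over `index * mask` then yields, for each position, the index of the most recent break, and subtracting that from the position index gives how long the current match run has lasted. A standalone rendering of the helper, with `blank_id` passed explicitly instead of read from `self.decoder`, plus a worked example:

```python
import torch
import torch.nn.functional as F

def compute_k(y: torch.Tensor, context_size: int = 2, blank_id: int = 0) -> torch.Tensor:
    # Shift y right by context_size, padding the front with blanks.
    y_shift = F.pad(y, (context_size, 0), mode="constant", value=blank_id)[:, :-context_size]
    mask = y_shift != y                        # True where the lag match breaks
    t = torch.arange(y.size(1), device=y.device).expand_as(y)
    last_break = (t * mask).cummax(dim=-1)[0]  # most recent break position
    return t - last_break                      # length of the current match run

y = torch.tensor([[0, 3, 3, 3, 3, 7]])
print(compute_k(y).tolist())  # [[0, 0, 0, 1, 2, 0]]: positions 3 and 4 match lag 2
```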
@@ -91,14 +113,8 @@ class Transducer(nn.Module):
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
             before padding.
           y:
-            A 2-D tensor with 2 axes [utt][label]. It contains labels of each
+            A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
-          y_lens:
-            A 1-D tensor of shape (N,). It contains the number of frames in `y`
-            before padding.
-          k:
-            A statistic given the context_size with respect to utt.
-            A 2-D tensor of shape (N, U).
           prune_range:
             The prune range for rnnt loss, it means how many symbols(context)
             we are considering for each frame to compute the loss.
@@ -119,24 +135,34 @@ class Transducer(nn.Module):
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
-        assert len(y.shape) == 2, len(y.shape)
+        assert y.num_axes == 2, y.num_axes
 
-        assert x.size(0) == x_lens.size(0) == y.size(0)
+        assert x.size(0) == x_lens.size(0) == y.dim0
 
         encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)
 
         # Now for the decoder, i.e., the prediction network
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
         blank_id = self.decoder.blank_id
+        sos_y = add_sos(y, sos_id=blank_id)
 
         # sos_y_padded: [B, S + 1], start with SOS.
-        sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)
+        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
 
+        # compute k
+        k = self._compute_k(sos_y_padded, context_size=self.decoder.context_size)
+
         # decoder_out: [B, S + 1, decoder_dim]
         decoder_out = self.decoder(sos_y_padded, k)
 
-        # Note: y_padded does not start with SOS
+        # Note: y does not start with SOS
         # y_padded : [B, S]
-        y_padded = y.to(torch.int64)
+        y_padded = y.pad(mode="constant", padding_value=0)
 
+        y_padded = y_padded.to(torch.int64)
         boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
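On the ragged side, `forward` now recovers per-utterance lengths from the ragged tensor's `row_splits` and pads via k2, instead of receiving `y_lens` and padded tensors from the caller. A small sketch of those k2 operations (assuming a recent k2 is installed; the toy labels are illustrative):

```python
import k2

y = k2.RaggedTensor([[2, 5, 5], [9]])      # two utterances: 3 labels and 1

row_splits = y.shape.row_splits(1)         # tensor([0, 3, 4])
y_lens = row_splits[1:] - row_splits[:-1]  # tensor([3, 1])

y_padded = y.pad(mode="constant", padding_value=0)
print(y_lens.tolist())                     # [3, 1]
print(y_padded.tolist())                   # [[2, 5, 5], [9, 0, 0]]
```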
@@ -39,7 +39,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless9/exp \
   --full-libri 1 \
-  --max-duration 550
+  --max-duration 750
 
 """
@@ -58,7 +58,6 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-import torch.nn.functional as F
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
@@ -70,7+69,6 @@ from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils.rnn import pad_sequence
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
@@ -237,7 +235,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless9/exp",
+        default="pruned_transducer_stateless7/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -627,41 +625,6 @@ def save_checkpoint(
     copyfile(src=filename, dst=best_valid_filename)
 
 
-def compute_k(
-    y: torch.Tensor,
-    context_size: int = 2,
-    blank_id: int = 0,
-) -> torch.Tensor:
-    """
-    Args:
-      y:
-        A 2-D tensor of shape (N, U).
-      context_size:
-        Number of previous words to use to predict the next word.
-        1 means bigram; 2 means trigram. n means (n+1)-gram.
-    Returns:
-      Return a tensor of shape (N, U).
-    """
-    y = F.pad(y, (1, 0), mode="constant", value=blank_id)  # [B, S + 1], start with SOS.
-    k = torch.zeros_like(y)
-
-    for i in range(2, y.size(1) - 1):
-        k[:, i : i + 1] = torch.where(
-            y[:, i : i + 1] != 0,
-            torch.sum(
-                (
-                    y[:, i - context_size : i]
-                    == y[:, i : i + 1].expand_as(y[:, i - context_size : i])
-                ).int(),
-                dim=1,
-                keepdim=True,
-            ),
-            y[:, i : i + 1],
-        )
-
-    return k
-
-
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
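Worth noting: the loop deleted here and the cummax helper added in model.py are not drop-in equivalents. The loop sums matches of `y[i]` inside the sliding window of the previous `context_size` tokens (and zeroes positions holding a blank), while the cummax form counts consecutive lag-`context_size` matches. A quick comparison on a toy input, re-stating both helpers locally (this is an observation about the diff, not a claim from the commit message):

```python
import torch
import torch.nn.functional as F

def k_windowed(y, context_size=2, blank_id=0):
    # Condensed restatement of the deleted train.py helper.
    y = F.pad(y, (1, 0), mode="constant", value=blank_id)  # prepend SOS
    k = torch.zeros_like(y)
    for i in range(2, y.size(1) - 1):
        matches = (y[:, i - context_size : i] == y[:, i : i + 1]).int().sum(1, keepdim=True)
        k[:, i : i + 1] = torch.where(y[:, i : i + 1] != 0, matches, y[:, i : i + 1])
    return k

def k_runlength(y, context_size=2, blank_id=0):
    # Restatement of the new model.py helper; it receives the SOS-padded labels.
    y_shift = F.pad(y, (context_size, 0), mode="constant", value=blank_id)[:, :-context_size]
    t = torch.arange(y.size(1)).expand_as(y)
    return t - (t * (y_shift != y)).cummax(dim=-1)[0]

y = torch.tensor([[3, 3, 3, 3]])
sos_y = F.pad(y, (1, 0), value=0)
print(k_windowed(y).tolist())       # [[0, 0, 1, 2, 0]]
print(k_runlength(sos_y).tolist())  # [[0, 0, 0, 1, 2]]
```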
@@ -712,26 +675,13 @@ def compute_loss(
 
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
-    y_lens = torch.tensor(list(map(len, y))).to(device)
-    y = list(map(torch.tensor, y))
-    y = pad_sequence(y, batch_first=True)  # [B, S]
-
-    k = compute_k(
-        y,
-        params.context_size,
-        model.module.decoder.blank_id
-        if isinstance(model, DDP)
-        else model.decoder.blank_id,
-    ).to(device)
-    y = y.to(device)
+    y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
-            y_lens=y_lens,
-            k=k,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
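With k moved into the model, the training loop no longer pads labels on the host or unwraps DDP to fetch `blank_id`; it just wraps the BPE ids in a ragged tensor. A sketch of the new data path (the `sp.encode` output is faked below, and the model call itself is omitted):

```python
import k2

# Stand-in for y = sp.encode(texts, out_type=int) in compute_loss.
y_ids = [[12, 7, 7, 30], [4, 9]]

y = k2.RaggedTensor(y_ids)  # no host-side padding, y_lens, or k needed
# The real code then moves y to the training device and calls
# model(x=..., x_lens=..., y=y, ...); the model derives y_lens from
# row_splits and k from self._compute_k.
```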
@@ -1093,10 +1043,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds