mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Add repeat_param to embedding_out during training and decoding
This commit is contained in:
parent
45b68e1df6
commit
c3e01e141f
@ -542,7 +542,9 @@ def greedy_search(
|
||||
[-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
|
||||
).reshape(1, context_size)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
k = torch.zeros(1, 1, device=device, dtype=torch.int64)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, k, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
@ -586,7 +588,20 @@ def greedy_search(
|
||||
1, context_size
|
||||
)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(
|
||||
1, context_size + 1
|
||||
)
|
||||
|
||||
k[:, 0] = torch.sum(
|
||||
(
|
||||
c[:, -context_size - 1 : -1]
|
||||
== c[:, -1].expand_as(c[:, -context_size - 1 : -1])
|
||||
).int(),
|
||||
dim=1,
|
||||
keepdim=True,
|
||||
)
|
||||
|
||||
decoder_out = model.decoder(decoder_input, k, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
sym_per_utt += 1
|
||||
|
@ -20,36 +20,36 @@
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
./pruned_transducer_stateless9/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 8 \
|
||||
--exp-dir ./pruned_transducer_stateless9/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
./pruned_transducer_stateless9/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--exp-dir ./pruned_transducer_stateless9/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
./pruned_transducer_stateless9/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--exp-dir ./pruned_transducer_stateless9/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
./pruned_transducer_stateless9/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--exp-dir ./pruned_transducer_stateless9/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
@ -57,10 +57,10 @@ Usage:
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
./pruned_transducer_stateless9/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--exp-dir ./pruned_transducer_stateless9/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
@ -70,10 +70,10 @@ Usage:
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
./pruned_transducer_stateless9/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--exp-dir ./pruned_transducer_stateless9/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
@ -83,10 +83,10 @@ Usage:
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless7/decode.py \
|
||||
./pruned_transducer_stateless9/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||
--exp-dir ./pruned_transducer_stateless9/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
@ -223,7 +223,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
default="pruned_transducer_stateless9/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
|
@ -73,14 +73,50 @@ class Decoder(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
self.repeat_param = nn.Parameter(torch.randn(decoder_dim))
|
||||
|
||||
def _add_repeat_param(
|
||||
self,
|
||||
embedding_out: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
is_training: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add the repeat parameter to the embedding_out tensor.
|
||||
|
||||
Args:
|
||||
embedding_out:
|
||||
A tensor of shape (N, U, decoder_dim).
|
||||
k:
|
||||
A tensor of shape (N, U).
|
||||
Should be (N, S + 1) during training.
|
||||
Should be (N, 1) during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
return embedding_out + torch.matmul(
|
||||
(k / (1 + k)).unsqueeze(2),
|
||||
self.repeat_param.unsqueeze(0),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
need_pad: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
k:
|
||||
A 2-D tensor, statistic given the context_size with respect to utt.
|
||||
Should be (N, S + 1) during training.
|
||||
Should be (N, 1) during inference.
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Whether to left pad the input.
|
||||
Should be True during training.
|
||||
Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
@ -98,5 +134,10 @@ class Decoder(nn.Module):
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
|
||||
embedding_out = self._add_repeat_param(
|
||||
embedding_out=embedding_out,
|
||||
k=k,
|
||||
)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
return embedding_out
|
||||
|
@ -20,10 +20,11 @@ import random
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import penalize_abs_values_gt
|
||||
|
||||
from icefall.utils import add_sos
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
@ -75,7 +76,9 @@ class Transducer(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
y: torch.Tensor,
|
||||
y_lens: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
@ -88,8 +91,14 @@ class Transducer(nn.Module):
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
A 2-D tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
y_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `y`
|
||||
before padding.
|
||||
k:
|
||||
A statistic given the context_size with respect to utt.
|
||||
A 2-D tensor of shape (N, U).
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
@ -110,31 +119,24 @@ class Transducer(nn.Module):
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
assert len(y.shape) == 2, len(y.shape)
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
assert x.size(0) == x_lens.size(0) == y.size(0)
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
decoder_out = self.decoder(sos_y_padded, k)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# Note: y_padded does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
y_padded = y.to(torch.int64)
|
||||
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
@ -22,22 +22,22 @@ Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless7/train.py \
|
||||
./pruned_transducer_stateless9/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir pruned_transducer_stateless7/exp \
|
||||
--exp-dir pruned_transducer_stateless9/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
|
||||
# For mix precision training:
|
||||
|
||||
./pruned_transducer_stateless7/train.py \
|
||||
./pruned_transducer_stateless9/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless7/exp \
|
||||
--exp-dir pruned_transducer_stateless9/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
|
||||
@ -58,6 +58,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
@ -69,6 +70,7 @@ from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer
|
||||
|
||||
@ -235,7 +237,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7/exp",
|
||||
default="pruned_transducer_stateless9/exp",
|
||||
help="""The experiment dir.
|
||||
It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
@ -625,6 +627,41 @@ def save_checkpoint(
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def compute_k(
|
||||
y: torch.Tensor,
|
||||
context_size: int = 2,
|
||||
blank_id: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U).
|
||||
"""
|
||||
y = F.pad(y, (1, 0), mode="constant", value=blank_id) # [B, S + 1], start with SOS.
|
||||
k = torch.zeros_like(y)
|
||||
|
||||
for i in range(2, y.size(1) - 1):
|
||||
k[:, i : i + 1] = torch.where(
|
||||
y[:, i : i + 1] != 0,
|
||||
torch.sum(
|
||||
(
|
||||
y[:, i - context_size : i]
|
||||
== y[:, i : i + 1].expand_as(y[:, i - context_size : i])
|
||||
).int(),
|
||||
dim=1,
|
||||
keepdim=True,
|
||||
),
|
||||
y[:, i : i + 1],
|
||||
)
|
||||
|
||||
return k
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
@ -675,13 +712,26 @@ def compute_loss(
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
y = sp.encode(texts, out_type=int)
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
y_lens = torch.tensor(list(map(len, y))).to(device)
|
||||
y = list(map(torch.tensor, y))
|
||||
y = pad_sequence(y, batch_first=True) # [B, S]
|
||||
|
||||
k = compute_k(
|
||||
y,
|
||||
params.context_size,
|
||||
model.module.decoder.blank_id
|
||||
if isinstance(model, DDP)
|
||||
else model.decoder.blank_id,
|
||||
).to(device)
|
||||
y = y.to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
y_lens=y_lens,
|
||||
k=k,
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
|
Loading…
x
Reference in New Issue
Block a user