mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Update aishell with k2 pruned rnnt loss
This commit is contained in:
parent
f94ff19bfe
commit
e46409e90f
@ -73,9 +73,9 @@ def greedy_search(
|
||||
continue
|
||||
|
||||
# fmt: off
|
||||
current_encoder_out = encoder_out[:, t:t+1, :]
|
||||
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
|
||||
# fmt: on
|
||||
logits = model.joiner(current_encoder_out, decoder_out)
|
||||
logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
|
||||
# logits is (1, 1, 1, vocab_size)
|
||||
|
||||
y = logits.argmax().item()
|
||||
|
@ -128,7 +128,7 @@ def get_params() -> AttributeDict:
|
||||
{
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"encoder_out_dim": 512,
|
||||
"embedding_dim": 256,
|
||||
"subsampling_factor": 4,
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
@ -145,7 +145,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
# TODO: We can add an option to switch between Conformer and Transformer
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
d_model=params.attention_dim,
|
||||
nhead=params.nhead,
|
||||
@ -159,7 +159,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
def get_decoder_model(params: AttributeDict):
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
embedding_dim=params.embedding_dim,
|
||||
blank_id=params.blank_id,
|
||||
context_size=params.context_size,
|
||||
)
|
||||
@ -168,8 +168,9 @@ def get_decoder_model(params: AttributeDict):
|
||||
|
||||
def get_joiner_model(params: AttributeDict):
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
input_dim=params.vocab_size,
|
||||
output_dim=params.vocab_size,
|
||||
inner_dim=params.embedding_dim,
|
||||
)
|
||||
return joiner
|
||||
|
||||
@ -408,7 +409,7 @@ def main():
|
||||
device=device,
|
||||
)
|
||||
|
||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
@ -70,6 +70,7 @@ class Decoder(nn.Module):
|
||||
groups=embedding_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.output_linear = nn.Linear(embedding_dim, vocab_size)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
@ -80,7 +81,7 @@ class Decoder(nn.Module):
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, embedding_dim).
|
||||
Return a tensor of shape (N, U, vocab_size).
|
||||
"""
|
||||
embeding_out = self.embedding(y)
|
||||
if self.context_size > 1:
|
||||
@ -95,4 +96,5 @@ class Decoder(nn.Module):
|
||||
assert embeding_out.size(-1) == self.context_size
|
||||
embeding_out = self.conv(embeding_out)
|
||||
embeding_out = embeding_out.permute(0, 2, 1)
|
||||
embeding_out = self.output_linear(embeding_out)
|
||||
return embeding_out
|
||||
|
@ -19,10 +19,12 @@ import torch.nn as nn
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(self, input_dim: int, output_dim: int):
|
||||
def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.output_linear = nn.Linear(input_dim, output_dim)
|
||||
self.output_linear = nn.Sequential(
|
||||
nn.Linear(input_dim, inner_dim), nn.Linear(inner_dim, output_dim)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
@ -36,15 +38,8 @@ class Joiner(nn.Module):
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, U, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim == 3
|
||||
assert encoder_out.size(0) == decoder_out.size(0)
|
||||
assert encoder_out.size(2) == decoder_out.size(2)
|
||||
|
||||
encoder_out = encoder_out.unsqueeze(2)
|
||||
# Now encoder_out is (N, T, 1, C)
|
||||
|
||||
decoder_out = decoder_out.unsqueeze(1)
|
||||
# Now decoder_out is (N, 1, U, C)
|
||||
assert encoder_out.ndim == decoder_out.ndim == 4
|
||||
assert encoder_out.shape == decoder_out.shape
|
||||
|
||||
logit = encoder_out + decoder_out
|
||||
logit = torch.tanh(logit)
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -14,15 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Note we use `rnnt_loss` from torchaudio, which exists only in
|
||||
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
|
||||
"""
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
import torchaudio.functional
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
@ -38,6 +32,9 @@ class Transducer(nn.Module):
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
prune_range: int = 5,
|
||||
lm_scale: float = 0.0,
|
||||
am_scale: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -62,6 +59,9 @@ class Transducer(nn.Module):
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
self.prune_range = prune_range
|
||||
self.lm_scale = lm_scale
|
||||
self.am_scale = am_scale
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -102,24 +102,38 @@ class Transducer(nn.Module):
|
||||
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
logits = self.joiner(encoder_out, decoder_out)
|
||||
|
||||
# rnnt_loss requires 0 padded targets
|
||||
# Note: y does not start with SOS
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
||||
f"Current torchaudio version: {torchaudio.__version__}\n"
|
||||
"Please install a version >= 0.10.0"
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(x.size(0), 4), dtype=torch.int64, device=x.device
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
decoder_out,
|
||||
encoder_out,
|
||||
y_padded,
|
||||
blank_id,
|
||||
lm_only_scale=self.lm_scale,
|
||||
am_only_scale=self.am_scale,
|
||||
boundary=boundary,
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
loss = torchaudio.functional.rnnt_loss(
|
||||
logits=logits,
|
||||
targets=y_padded,
|
||||
logit_lengths=x_lens,
|
||||
target_lengths=y_lens,
|
||||
blank=blank_id,
|
||||
reduction="sum",
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad, py_grad, boundary, self.prune_range
|
||||
)
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
encoder_out, decoder_out, ranges
|
||||
)
|
||||
|
||||
return loss
|
||||
logits = self.joiner(am_pruned, lm_pruned)
|
||||
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits, y_padded, ranges, blank_id, boundary
|
||||
)
|
||||
|
||||
return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
|
||||
|
@ -38,7 +38,6 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from torch import Tensor
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
|
||||
@ -129,6 +128,28 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prune-range",
|
||||
type=int,
|
||||
default=5,
|
||||
help="The prune range for rnnt loss, it means how many symbols(context)"
|
||||
"we are using to compute the loss",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with lm (output of prediction network) part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network) part.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -185,18 +206,19 @@ def get_params() -> AttributeDict:
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000, # For the 100h subset, use 800
|
||||
"valid_interval": 800,
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"encoder_out_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
# parameters for decoder
|
||||
"embedding_dim": 256,
|
||||
# parameters for Noam
|
||||
"warm_step": 80000, # For the 100h subset, use 8k
|
||||
"warm_step": 30000,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
@ -208,7 +230,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
# TODO: We can add an option to switch between Conformer and Transformer
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
d_model=params.attention_dim,
|
||||
nhead=params.nhead,
|
||||
@ -222,7 +244,7 @@ def get_encoder_model(params: AttributeDict):
|
||||
def get_decoder_model(params: AttributeDict):
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
embedding_dim=params.embedding_dim,
|
||||
blank_id=params.blank_id,
|
||||
context_size=params.context_size,
|
||||
)
|
||||
@ -231,7 +253,8 @@ def get_decoder_model(params: AttributeDict):
|
||||
|
||||
def get_joiner_model(params: AttributeDict):
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
input_dim=params.vocab_size,
|
||||
inner_dim=params.embedding_dim,
|
||||
output_dim=params.vocab_size,
|
||||
)
|
||||
return joiner
|
||||
@ -246,6 +269,9 @@ def get_transducer_model(params: AttributeDict):
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
prune_range=params.prune_range,
|
||||
lm_scale=params.lm_scale,
|
||||
am_scale=params.am_scale,
|
||||
)
|
||||
return model
|
||||
|
||||
@ -374,7 +400,8 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
loss = model(x=feature, x_lens=feature_lens, y=y)
|
||||
simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
|
||||
loss = simple_loss + pruned_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
@ -383,6 +410,8 @@ def compute_loss(
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||
|
||||
return loss, info
|
||||
|
||||
@ -476,7 +505,6 @@ def train_one_epoch(
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
@ -555,10 +583,9 @@ def run(rank, world_size, args):
|
||||
graph_compiler = CharCtcTrainingGraphCompiler(
|
||||
lexicon=lexicon,
|
||||
device=device,
|
||||
oov="<unk>",
|
||||
)
|
||||
|
||||
params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
Loading…
x
Reference in New Issue
Block a user