mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Initial tcpgen biasing
This commit is contained in:
parent
1f5c0a87b9
commit
cfb0cab3eb
284
egs/librispeech/ASR/zipformer/biasing.py
Normal file
284
egs/librispeech/ASR/zipformer/biasing.py
Normal file
@ -0,0 +1,284 @@
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from icefall import ContextGraph
|
||||
import logging
|
||||
import random
|
||||
|
||||
|
||||
class TCPGen(nn.Module):
    """TCPGen contextual-biasing component for a pruned transducer.

    For every (frame, symbol) position kept by the pruned RNN-T ranges, the
    module attends over the vocabulary embeddings (plus one extra
    "out of biasing list" (ool) embedding), restricted to the tokens that may
    follow the current prefix in the biasing trie.  It returns the attended
    embedding (``hptr``), the attention distribution over ``V + 1`` entries and
    a mask telling where the generation probability must be forced to zero.
    """

    def __init__(
        self,
        encoder_dim: int,
        embed_dim: int,
        joiner_dim: int,
        decoder: nn.Module,
        attn_dim: int = 512,
        tcpgen_dropout: float = 0.1,
    ):
        """
        Args:
          encoder_dim:
            Last dimension of the encoder output fed to `forward`.
          embed_dim:
            Dimension of `decoder.embedding`; also used for the ool embedding.
          joiner_dim:
            Dimension of the joiner activations given to `generator_prob`.
          decoder:
            The prediction network. It must expose `embedding` (an embedding
            module over the vocabulary) and `vocab_size`.
          attn_dim:
            Dimension of the attention queries / keys (and of `hptr`).
          tcpgen_dropout:
            Dropout probability applied to the projected keys and to `hptr`.
        """
        super().__init__()
        # embedding for ool (out of biasing list) token
        self.oolemb = torch.nn.Embedding(1, embed_dim)
        # project symbol embeddings
        self.q_proj_sym = torch.nn.Linear(embed_dim, attn_dim)
        # project encoder embeddings
        self.q_proj_acoustic = torch.nn.Linear(encoder_dim, attn_dim)
        # project symbol embeddings (vocabulary + ool)
        self.k_proj = torch.nn.Linear(embed_dim, attn_dim)
        # generate tcpgen probability
        self.tcp_gate = torch.nn.Linear(attn_dim + joiner_dim, 1)
        self.dropout_tcpgen = torch.nn.Dropout(tcpgen_dropout)
        # NOTE(review): assigning the decoder here registers it as a submodule
        # of TCPGen as well, so its parameters are reachable through both
        # paths -- confirm this is what checkpointing expects.
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size

    def get_tcpgen_masks(
        self,
        targets: torch.Tensor,
        context_graph: "ContextGraph",
        vocab_size: int,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Walk the biasing trie along each target sequence and build the masks.

        In both returned masks ``True`` means "masked out".

        Args:
          targets:
            Token ids with shape (B, S).
          context_graph:
            The biasing trie. Only `root`, and `next` / `is_end` of its nodes,
            are used.
          vocab_size:
            Vocabulary size V; the distribution mask has V + 1 entries per
            position (the last one is the ool entry).
          blank_id:
            A blank resets the walk to the trie root.

        Returns:
          Return a tuple (dist_masks, gen_masks) of bool tensors on
          `targets.device`:
            - dist_masks, shape (B, S, V + 1): ``False`` at the tokens that can
              follow the trie state reached after consuming `targets[b, s]`.
              Rows where the token fell off the trie stay fully masked.
            - gen_masks, shape (B, S): ``True`` where `targets[b, s]` could not
              be matched in the trie.
        """
        batch_size, seq_len = targets.shape
        root = context_graph.root

        dist_masks = torch.ones(
            (batch_size, seq_len, vocab_size + 1), dtype=torch.bool
        )
        gen_rows = []
        # Coordinates of every entry to unmask; they are written in a single
        # indexed assignment after the walk instead of one tensor write per
        # position.
        rows, cols, toks = [], [], []

        for i, yseq in enumerate(targets.tolist()):
            node = root
            gen_row = []
            for j, y in enumerate(yseq):
                if y == blank_id:
                    node = root
                elif y in node.next:
                    node = node.next[y]
                    if node.is_end:
                        # A complete phrase was matched, start over.
                        node = root
                else:
                    # Fell off the trie: no pointer distribution for this
                    # position and generation probability forced to zero.
                    node = root
                    gen_row.append(True)
                    continue
                gen_row.append(False)
                # NOTE: only the trie successors are unmasked; the ool entry
                # (index `vocab_size`) is never unmasked here.
                for token in node.next:
                    rows.append(i)
                    cols.append(j)
                    toks.append(token)
            gen_rows.append(gen_row)

        if rows:
            dist_masks[
                torch.tensor(rows, dtype=torch.long),
                torch.tensor(cols, dtype=torch.long),
                torch.tensor(toks, dtype=torch.long),
            ] = False

        # The reshape keeps the (B, S) shape even when the batch is empty.
        gen_masks = (
            torch.tensor(gen_rows, dtype=torch.bool)
            .reshape(batch_size, seq_len)
            .to(targets.device)
        )
        dist_masks = dist_masks.to(targets.device)

        # Sampled debug logging (roughly 5% of the calls).
        if random.random() >= 0.95:
            logging.info(
                f"gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=2)}"
            )
        return dist_masks, gen_masks

    def get_tcpgen_distribution(
        self, query: torch.Tensor, dist_masks: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the masked attention over vocabulary + ool embeddings.

        Args:
          query:
            shape : (B, T, s_range, attn_dim)
          dist_masks:
            shape : (B, T, s_range, V + 1), ``True`` means masked.

        Returns:
          Return a tuple (hptr, distribution):
            - hptr, shape (B, T, s_range, attn_dim): attention-weighted sum of
              the projected vocabulary embeddings (the ool entry is excluded).
            - distribution, shape (B, T, s_range, V + 1): softmax attention
              weights. A row whose entries are all masked comes out uniform.
        """
        # From original paper, k, v share same embeddings
        # (V + 1, embed_dim)
        # `.data` bypasses autograd for the decoder embedding, so this path
        # does not back-propagate into it (presumably intended -- TODO confirm).
        kv = torch.cat([self.decoder.embedding.weight.data, self.oolemb.weight], dim=0)
        # (V + 1, attn_dim)
        kv = self.dropout_tcpgen(self.k_proj(kv))
        # (B, T, s_range, attn_dim) * (attn_dim, V + 1) -> (B, T, s_range, V + 1)
        distribution = torch.matmul(query, kv.permute(1, 0)) / math.sqrt(query.size(-1))
        distribution = distribution.masked_fill(
            dist_masks, torch.finfo(distribution.dtype).min
        )
        distribution = distribution.softmax(dim=-1)
        # (B, T, s_range, V) * (V, attn_dim) -> (B, T, s_range, attn_dim)
        hptr = torch.matmul(distribution[:, :, :, :-1], kv[:-1, :])
        hptr = self.dropout_tcpgen(hptr)
        # Sampled debug logging (roughly 5% of the calls).
        if random.random() > 0.95:
            logging.info(
                f"distribution mean : {torch.mean(distribution, dim=3)}, std: {torch.std(distribution, dim=3)}"
            )
            logging.info(
                f"distribution min : {torch.min(distribution, dim=3)}, max: {torch.max(distribution, dim=3)}"
            )
        return hptr, distribution

    def prune_query_and_mask(
        self,
        query_sym: torch.Tensor,
        query_acoustic: torch.Tensor,
        dist_masks: torch.Tensor,
        gen_masks: torch.Tensor,
        ranges: torch.Tensor,
    ):
        """Prune the queries from symbols and acoustics with ranges
        generated by `get_rnnt_prune_ranges` in pruned rnnt loss.

        Args:
          query_sym:
            The symbol query, with shape (B, S, attn_dim).
          query_acoustic:
            The acoustic query, with shape (B, T, attn_dim).
          dist_masks:
            The TCPGen distribution masks, with shape (B, S, V + 1).
          gen_masks:
            The TCPGen probability masks, with shape (B, S).
          ranges:
            A tensor containing the symbol indexes for each frame that we want to
            keep. Its shape is (B, T, s_range), see the docs in
            `get_rnnt_prune_ranges` in rnnt_loss.py for more details of this tensor.

        Returns:
          Return a tuple of three tensors:
            - the pruned query (symbol query + acoustic query), with shape
              (B, T, s_range, attn_dim);
            - the pruned distribution masks, with shape (B, T, s_range, V + 1);
            - the pruned probability masks, with shape (B, T, s_range).
        """
        assert ranges.shape[0] == query_sym.shape[0], (ranges.shape, query_sym.shape)
        assert ranges.shape[0] == query_acoustic.shape[0], (
            ranges.shape,
            query_acoustic.shape,
        )
        assert query_acoustic.shape[1] == ranges.shape[1], (
            query_acoustic.shape,
            ranges.shape,
        )
        (B, T, s_range) = ranges.shape
        (_, S, attn_dim) = query_sym.shape
        num_entries = self.vocab_size + 1
        assert query_acoustic.shape == (B, T, attn_dim), (
            query_acoustic.shape,
            (B, T, attn_dim),
        )
        assert dist_masks.shape == (B, S, num_entries), (
            dist_masks.shape,
            (B, S, num_entries),
        )
        assert gen_masks.shape == (B, S), (gen_masks.shape, (B, S))

        # Flattened symbol indexes, shared by the three gathers below.
        flat_ranges = ranges.reshape(B, T * s_range)
        flat_ranges_3d = flat_ranges.unsqueeze(-1)

        # (B, T, s_range, attn_dim)
        query_acoustic_pruned = query_acoustic.unsqueeze(2).expand(
            (B, T, s_range, attn_dim)
        )
        # (B, T, s_range, attn_dim)
        query_sym_pruned = torch.gather(
            query_sym,
            dim=1,
            index=flat_ranges_3d.expand((B, T * s_range, attn_dim)),
        ).reshape(B, T, s_range, attn_dim)
        # (B, T, s_range, V + 1)
        dist_masks_pruned = torch.gather(
            dist_masks,
            dim=1,
            index=flat_ranges_3d.expand((B, T * s_range, num_entries)),
        ).reshape(B, T, s_range, num_entries)
        # (B, T, s_range)
        gen_masks_pruned = torch.gather(gen_masks, dim=1, index=flat_ranges).reshape(
            B, T, s_range
        )
        return (
            query_sym_pruned + query_acoustic_pruned,
            dist_masks_pruned,
            gen_masks_pruned,
        )

    def forward(
        self,
        targets: torch.Tensor,
        encoder_embeddings: torch.Tensor,
        ranges: torch.Tensor,
        context_graph: "ContextGraph",
    ):
        """
        Args:
          targets:
            The training targets in token ids (padded with blanks). shape : (B, S)
          encoder_embeddings:
            The encoder outputs. shape: (B, T, encoder_dim)
          ranges:
            The prune ranges from pruned rnnt. shape: (B, T, s_range)
          context_graph:
            The biasing trie, shared by all the utterances in the batch.

        Return:
          returns a tuple of three tensors: the tcpgen embedding with shape
          (B, T, s_range, attn_dim), the tcpgen distribution with shape
          (B, T, s_range, V + 1) and the pruned generation masks with shape
          (B, T, s_range) (``True`` where the tcpgen probability must be zero).
        """
        query_sym = self.decoder.embedding(targets)

        query_sym = self.q_proj_sym(query_sym)  # (B, S, attn_dim)
        query_acoustic = self.q_proj_acoustic(encoder_embeddings)  # (B , T, attn_dim)

        # dist_masks : (B, S, V + 1)
        # gen_masks : (B, S)
        dist_masks, gen_masks = self.get_tcpgen_masks(
            targets=targets, context_graph=context_graph, vocab_size=self.vocab_size
        )
        # query : (B, T, s_range, attn_dim)
        # dist_masks : (B, T, s_range, V + 1)
        # gen_masks : (B, T, s_range)
        query, dist_masks, gen_masks = self.prune_query_and_mask(
            query_sym=query_sym,
            query_acoustic=query_acoustic,
            dist_masks=dist_masks,
            gen_masks=gen_masks,
            ranges=ranges,
        )

        # Sampled debug logging (roughly 5% of the calls).
        if random.random() >= 0.95:
            logging.info(
                f"pruned gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"pruned dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=3)}"
            )

        # hptr : (B, T, s_range, attn_dim)
        # tcpgen_dist : (B, T, s_range, V + 1)
        hptr, tcpgen_dist = self.get_tcpgen_distribution(query, dist_masks)
        return hptr, tcpgen_dist, gen_masks

    def generator_prob(
        self, hptr: torch.Tensor, h_joiner: torch.Tensor, gen_masks: torch.Tensor
    ) -> torch.Tensor:
        """Compute the probability of using the tcpgen distribution.

        Args:
          hptr:
            The tcpgen embedding, with shape (B, T, s_range, attn_dim).
          h_joiner:
            The joiner activations, with shape (B, T, s_range, joiner_dim).
          gen_masks:
            Bool tensor with shape (B, T, s_range); the probability is set to
            zero where it is ``True``.

        Returns:
          Return the gate value in [0, 1], with shape (B, T, s_range, 1).
        """
        # tcpgen_prob : (B, T, s_range, 1)
        tcpgen_prob = self.tcp_gate(torch.cat((h_joiner, hptr), dim=-1))
        tcpgen_prob = torch.sigmoid(tcpgen_prob)
        tcpgen_prob = tcpgen_prob.masked_fill(gen_masks.unsqueeze(-1), 0)
        # Sampled debug logging (roughly 5% of the calls).
        if random.random() >= 0.95:
            logging.info(
                f"tcpgen_prob mean : {torch.mean(tcpgen_prob.squeeze(-1), dim=(1,2))}, std : {torch.std(tcpgen_prob.squeeze(-1), dim=(1, 2))}"
            )
            logging.info(
                f"tcpgen_prob min : {torch.min(tcpgen_prob.squeeze(-1), dim=1)}, max : {torch.max(tcpgen_prob.squeeze(-1), dim=1)}"
            )
        return tcpgen_prob
|
@ -17,6 +17,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ScaledLinear
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
@ -37,6 +38,7 @@ class Joiner(nn.Module):
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
tcpgen_hptr: Optional[torch.Tensor] = None,
|
||||
project_input: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@ -62,6 +64,11 @@ class Joiner(nn.Module):
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
if tcpgen_hptr is not None:
|
||||
logit += tcpgen_hptr
|
||||
|
||||
return logit
|
||||
activations = torch.tanh(logit)
|
||||
|
||||
logit = self.output_linear(activations)
|
||||
|
||||
return logit if tcpgen_hptr is None else (logit, activations)
|
||||
|
@ -17,14 +17,17 @@
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
import logging
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
from biasing import TCPGen
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
from icefall import ContextGraph
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
|
||||
@ -36,9 +39,13 @@ class AsrModel(nn.Module):
|
||||
joiner: Optional[nn.Module] = None,
|
||||
encoder_dim: int = 384,
|
||||
decoder_dim: int = 512,
|
||||
joiner_dim: int = 512,
|
||||
vocab_size: int = 500,
|
||||
tcpgen_attn_dim: int = 512,
|
||||
use_transducer: bool = True,
|
||||
use_ctc: bool = False,
|
||||
use_tcpgen_biasing: bool = False,
|
||||
tcpgen_dropout: float = 0.15,
|
||||
):
|
||||
"""A joint CTC & Transducer ASR model.
|
||||
|
||||
@ -111,6 +118,19 @@ class AsrModel(nn.Module):
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
self.use_tcpgen_biasing = use_tcpgen_biasing
|
||||
if use_tcpgen_biasing:
|
||||
assert use_transducer, "TCPGen biasing only support on transducer model."
|
||||
self.tcp_gen = TCPGen(
|
||||
encoder_dim=encoder_dim,
|
||||
embed_dim=decoder_dim,
|
||||
attn_dim=tcpgen_attn_dim,
|
||||
joiner_dim=joiner_dim,
|
||||
decoder=decoder,
|
||||
tcpgen_dropout=tcpgen_dropout,
|
||||
)
|
||||
self.hptr_proj = torch.nn.Linear(tcpgen_attn_dim, joiner_dim)
|
||||
|
||||
def forward_encoder(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -180,6 +200,7 @@ class AsrModel(nn.Module):
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute Transducer loss.
|
||||
Args:
|
||||
@ -202,6 +223,7 @@ class AsrModel(nn.Module):
|
||||
"""
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
blank_id = self.decoder.blank_id
|
||||
assert blank_id == 0, f"Assuming blank_id is 0, given {blank_id}"
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
@ -226,11 +248,6 @@ class AsrModel(nn.Module):
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
@ -252,19 +269,66 @@ class AsrModel(nn.Module):
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
# am_pruned : [B, T, prune_range, joiner_dim]
|
||||
# lm_pruned : [B, T, prune_range, joiner_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
fused_log_softmax = True
|
||||
if self.use_tcpgen_biasing and context_graph is not None:
|
||||
# logging.info(f"using tcpgen baising")
|
||||
fused_log_softmax = False
|
||||
# hptr : (B, T, s_range, attn_dim)
|
||||
# tcpgen_dist : (B, T, s_range, V + 1)
|
||||
# gen_masks : (B, T, s_range)
|
||||
hptr, tcpgen_dist, gen_masks = self.tcp_gen(
|
||||
targets=sos_y_padded,
|
||||
encoder_embeddings=encoder_out,
|
||||
ranges=ranges,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
# (B, T, s_range, joiner_dim)
|
||||
tcpgen_hptr = self.hptr_proj(hptr)
|
||||
# logits : (B, T, s_range, V)
|
||||
# activations : (B, T, s_range, joiner_dim)
|
||||
logits, activations = self.joiner(
|
||||
encoder_out=am_pruned,
|
||||
decoder_out=lm_pruned,
|
||||
tcpgen_hptr=tcpgen_hptr,
|
||||
project_input=False,
|
||||
)
|
||||
# (B, T, s_range, 1)
|
||||
tcpgen_prob = self.tcp_gen.generator_prob(
|
||||
hptr=hptr, h_joiner=activations, gen_masks=gen_masks
|
||||
)
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
# Assuming blank_id is 0
|
||||
p_mdl = torch.softmax(logits, dim=-1)
|
||||
p_no_blank = 1.0 - p_mdl[:, :, :, 0:1]
|
||||
|
||||
# blank dist (tcpgen_dist[:,:,:,0]) should be 0
|
||||
scaled_tcpgen_dist = tcpgen_dist[:, :, :, 1:] * p_no_blank
|
||||
|
||||
# ool probability : tcpgen_dist[:, :, :, -1:]
|
||||
scaled_tcpgen_prob = tcpgen_prob * (1 - tcpgen_dist[:, :, :, -1:])
|
||||
|
||||
interpolated_no_blank = scaled_tcpgen_dist[
|
||||
:, :, :, :-1
|
||||
] * tcpgen_prob + p_mdl[:, :, :, 1:] * (1 - scaled_tcpgen_prob)
|
||||
|
||||
final_dist = torch.cat([p_mdl[:, :, :, 0:1], interpolated_no_blank], dim=-1)
|
||||
|
||||
# actually logprobs
|
||||
logits = torch.log(final_dist + torch.finfo(final_dist.dtype).tiny)
|
||||
else:
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
@ -274,6 +338,7 @@ class AsrModel(nn.Module):
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
fused_log_softmax=fused_log_softmax,
|
||||
)
|
||||
|
||||
return simple_loss, pruned_loss
|
||||
@ -286,6 +351,7 @@ class AsrModel(nn.Module):
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
context_graph: Optional[ContextGraph] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -338,6 +404,7 @@ class AsrModel(nn.Module):
|
||||
prune_range=prune_range,
|
||||
am_scale=am_scale,
|
||||
lm_scale=lm_scale,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
else:
|
||||
simple_loss = torch.empty(0)
|
||||
|
@ -54,10 +54,11 @@ It supports training with:
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
@ -81,7 +82,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from zipformer import Zipformer2
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall import ContextGraph, diagnostics
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import (
|
||||
@ -259,6 +260,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
help="If True, use CTC head.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-tcpgen-biasing",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If True, use tcpgen context biasing module",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -526,6 +534,7 @@ def get_params() -> AttributeDict:
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"epoch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000, # For the 100h subset, use 800
|
||||
@ -534,6 +543,10 @@ def get_params() -> AttributeDict:
|
||||
"subsampling_factor": 4, # not passed in, this is fixed.
|
||||
"warm_step": 2000,
|
||||
"env_info": get_env_info(),
|
||||
# parameters for tcpgen
|
||||
"tcpgen_start_epoch": 10,
|
||||
"num_distrators": 500,
|
||||
"distractors_list": [],
|
||||
}
|
||||
)
|
||||
|
||||
@ -628,9 +641,11 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
joiner=joiner,
|
||||
encoder_dim=max(_to_int_tuple(params.encoder_dim)),
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
use_transducer=params.use_transducer,
|
||||
use_ctc=params.use_ctc,
|
||||
use_tcpgen_biasing=params.use_tcpgen_biasing,
|
||||
)
|
||||
return model
|
||||
|
||||
@ -688,6 +703,7 @@ def load_checkpoint_if_available(
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"epoch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
@ -751,6 +767,40 @@ def save_checkpoint(
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def prepare_context_graph(
    params: AttributeDict,
    texts: List[str],
    sp: spm.SentencePieceProcessor,
) -> Optional[ContextGraph]:
    """Sample biasing phrases for the current batch and build a context graph.

    Args:
      params:
        The training parameters. Reads `epoch_idx_train`, `start_epoch`,
        `tcpgen_start_epoch` and `num_distrators`; reads and extends
        `distractors_list`.
      texts:
        The transcripts of the current batch.
      sp:
        The BPE model used to tokenize the sampled phrases.

    Returns:
      Return None while `params.epoch_idx_train` equals `params.start_epoch`
      (that epoch is only used to collect the transcripts as distractors),
      while it is below `params.tcpgen_start_epoch`, or when no phrase could
      be sampled. Otherwise return a ContextGraph built from the sampled
      phrases.
    """
    if params.epoch_idx_train == params.start_epoch:
        params.distractors_list += texts
        return None

    logging.info(f"distractors_list : {len(params.distractors_list)}")
    if params.epoch_idx_train < params.tcpgen_start_epoch:
        return None

    # Keep each transcript of the batch with probability 0.5 ...
    selected_texts = [text for text in texts if random.random() < 0.5]

    # ... and add randomly picked distractors. Nothing can be picked from an
    # empty list (random.randint(0, -1) raises ValueError).
    distractors = params.distractors_list
    if distractors:
        for _ in range(params.num_distrators):
            index = random.randint(0, len(distractors) - 1)
            selected_texts.append(distractors[index])

    # Cut a random span of up to three consecutive words out of every text.
    contexts_list = []
    for text in selected_texts:
        word_list = text.split()
        if not word_list:
            # An empty transcript has no word to start a phrase from.
            continue
        start = random.randint(0, len(word_list) - 1)
        length = random.randint(1, 3)
        contexts_list.append(" ".join(word_list[start : start + length]))

    if not contexts_list:
        return None

    contexts_tokens = sp.encode(contexts_list, out_type=int)
    context_graph = ContextGraph(0)
    context_graph.build(contexts_tokens)
    return context_graph
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
@ -792,6 +842,11 @@ def compute_loss(
|
||||
y = sp.encode(texts, out_type=int)
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
context_graph = None
|
||||
if params.use_tcpgen_biasing:
|
||||
context_graph = prepare_context_graph(params=params, texts=texts, sp=sp)
|
||||
# assert context_graph is not None
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
x=feature,
|
||||
@ -800,6 +855,7 @@ def compute_loss(
|
||||
prune_range=params.prune_range,
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
context_graph=context_graph,
|
||||
)
|
||||
|
||||
loss = 0.0
|
||||
@ -939,6 +995,8 @@ def train_one_epoch(
|
||||
rank=0,
|
||||
)
|
||||
|
||||
logging.info(f"epoch_idx_train : {params.epoch_idx_train}")
|
||||
params.epoch_idx_train += 1
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
if batch_idx % 10 == 0:
|
||||
set_batch_count(model, get_adjusted_batch_count(params))
|
||||
@ -967,9 +1025,7 @@ def train_one_epoch(
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
except Exception as e:
|
||||
logging.info(
|
||||
f"Caught exception: {e}."
|
||||
)
|
||||
logging.info(f"Caught exception: {e}.")
|
||||
save_bad_model()
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
@ -1177,16 +1233,16 @@ def run(rank, world_size, args):
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
if params.full_libri:
|
||||
train_cuts = librispeech.train_all_shuf_cuts()
|
||||
# train_cuts = librispeech.train_all_shuf_cuts()
|
||||
|
||||
# previously we used the following code to load all training cuts,
|
||||
# strictly speaking, shuffled training cuts should be used instead,
|
||||
# but we leave the code here to demonstrate that there is an option
|
||||
# like this to combine multiple cutsets
|
||||
|
||||
# train_cuts = librispeech.train_clean_100_cuts()
|
||||
# train_cuts += librispeech.train_clean_360_cuts()
|
||||
# train_cuts += librispeech.train_other_500_cuts()
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
train_cuts += librispeech.train_clean_360_cuts()
|
||||
train_cuts += librispeech.train_other_500_cuts()
|
||||
else:
|
||||
train_cuts = librispeech.train_clean_100_cuts()
|
||||
|
||||
@ -1205,26 +1261,6 @@ def run(rank, world_size, args):
|
||||
# )
|
||||
return False
|
||||
|
||||
# In pruned RNN-T, we require that T >= S
|
||||
# where T is the number of feature frames after subsampling
|
||||
# and S is the number of tokens in the utterance
|
||||
|
||||
# In ./zipformer.py, the conv module uses the following expression
|
||||
# for subsampling
|
||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
||||
|
||||
if T < len(tokens):
|
||||
logging.warning(
|
||||
f"Exclude cut with ID {c.id} from training. "
|
||||
f"Number of frames (before subsampling): {c.num_frames}. "
|
||||
f"Number of frames (after subsampling): {T}. "
|
||||
f"Text: {c.supervisions[0].text}. "
|
||||
f"Tokens: {tokens}. "
|
||||
f"Number of tokens: {len(tokens)}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
@ -1245,13 +1281,14 @@ def run(rank, world_size, args):
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
)
|
||||
pass
|
||||
# scan_pessimistic_batches_for_oom(
|
||||
# model=model,
|
||||
# train_dl=train_dl,
|
||||
# optimizer=optimizer,
|
||||
# sp=sp,
|
||||
# params=params,
|
||||
# )
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
@ -1397,4 +1434,5 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_printoptions(profile="full")
|
||||
main()
|
||||
|
@ -162,6 +162,7 @@ class ContextGraph:
|
||||
phrases: Optional[List[str]] = None,
|
||||
scores: Optional[List[float]] = None,
|
||||
ac_thresholds: Optional[List[float]] = None,
|
||||
simple_trie: bool = False,
|
||||
):
|
||||
"""Build the ContextGraph from a list of token list.
|
||||
It first build a trie from the given token lists, then fill the fail arc
|
||||
@ -189,6 +190,11 @@ class ContextGraph:
|
||||
0 means using the default value (i.e. self.ac_threshold). It is
|
||||
used only when this graph applied for the keywords spotting system.
|
||||
The length of `ac_threshold` MUST be equal to the length of `token_ids`.
|
||||
simple_trie:
|
||||
True for building only trie (i.e. no fail and output arcs). Needed by
|
||||
tcpgen biasing training.
|
||||
False for building an Aho-Corasick automaton, for hotword / keywords
|
||||
searching.
|
||||
|
||||
Note: The phrases would have shared states, the score of the shared states is
|
||||
the MAXIMUM value among all the tokens sharing this state.
|
||||
@ -211,7 +217,6 @@ class ContextGraph:
|
||||
context_score = self.context_score if score == 0.0 else score
|
||||
threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold
|
||||
for i, token in enumerate(tokens):
|
||||
node_next = {}
|
||||
if token not in node.next:
|
||||
self.num_nodes += 1
|
||||
is_end = i == len(tokens) - 1
|
||||
@ -240,7 +245,9 @@ class ContextGraph:
|
||||
node.next[token].phrase = phrase
|
||||
node.next[token].ac_threshold = threshold
|
||||
node = node.next[token]
|
||||
self._fill_fail_output()
|
||||
|
||||
if not simple_trie:
|
||||
self._fill_fail_output()
|
||||
|
||||
def forward_one_step(
|
||||
self, state: ContextState, token: int, strict_mode: bool = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user