Merge cfb0cab3ebf719b393960ea3565361353f74eb28 into cbcac23d2617ccfdc8f1ecc14a00ba96413c3bf9

Commit bacbffac11 by Wei Kang, 2024-07-04 14:49:49 +08:00, committed by GitHub
5 changed files with 454 additions and 51 deletions

biasing.py (new file)

@@ -0,0 +1,284 @@
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import random
from typing import Tuple

import torch
import torch.nn as nn

from icefall import ContextGraph

class TCPGen(nn.Module):
    def __init__(
        self,
        encoder_dim: int,
        embed_dim: int,
        joiner_dim: int,
        decoder: nn.Module,
        attn_dim: int = 512,
        tcpgen_dropout: float = 0.1,
    ):
        super().__init__()
        # embedding for the OOL (out of biasing list) token
        self.oolemb = torch.nn.Embedding(1, embed_dim)
        # project symbol embeddings
        self.q_proj_sym = torch.nn.Linear(embed_dim, attn_dim)
        # project encoder embeddings
        self.q_proj_acoustic = torch.nn.Linear(encoder_dim, attn_dim)
        # project symbol embeddings (vocabulary + OOL)
        self.k_proj = torch.nn.Linear(embed_dim, attn_dim)
        # generate the TCPGen probability
        self.tcp_gate = torch.nn.Linear(attn_dim + joiner_dim, 1)
        self.dropout_tcpgen = torch.nn.Dropout(tcpgen_dropout)
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size

    def get_tcpgen_masks(
        self,
        targets: torch.Tensor,
        context_graph: ContextGraph,
        vocab_size: int,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
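        """Build the per-position TCPGen masks from the biasing-list trie.

        Args:
          targets:
            The training targets in token ids (padded with blanks), shape (B, S).
          context_graph:
            The trie built from the biasing phrases.
          vocab_size:
            The vocabulary size V; index V is the OOL (out of biasing list) token.
          blank_id:
            The id of the blank symbol, assumed to be 0.

        Returns:
          dist_masks of shape (B, S, V + 1), which is False (i.e. unmasked) only
          at the tokens that extend the current trie path, and gen_masks of
          shape (B, S), which is True at the positions where the TCPGen
          probability will be forced to zero because the history does not match
          any biasing phrase.
        """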
        batch_size, seq_len = targets.shape
        dist_masks = torch.ones((batch_size, seq_len, vocab_size + 1))
        gen_masks = []
        yseqs = targets.tolist()
        for i, yseq in enumerate(yseqs):
            node = context_graph.root
            gen_mask = []
            for j, y in enumerate(yseq):
                not_matched = False
                if y == blank_id:
                    node = context_graph.root
                    gen_mask.append(0)
                elif y in node.next:
                    gen_mask.append(0)
                    node = node.next[y]
                    if node.is_end:
                        node = context_graph.root
                else:
                    gen_mask.append(1)
                    node = context_graph.root
                    not_matched = True
                # unmask_index = (
                #     [vocab_size]
                #     if node.token == -1
                #     else list(node.next.keys()) + [vocab_size]
                # )
                # logging.info(f"token : {node.token}, keys : {node.next.keys()}")
                # dist_masks[i, j, unmask_index] = 0
                if not not_matched:
                    dist_masks[i, j, list(node.next.keys())] = 0
            gen_masks.append(gen_mask + [1] * (seq_len - len(gen_mask)))
        gen_masks = torch.Tensor(gen_masks).to(targets.device).bool()
        dist_masks = dist_masks.to(targets.device).bool()
        if random.random() >= 0.95:
            logging.info(
                f"gen_mask nonzero {gen_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"dist_masks nonzero {dist_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(dist_masks), dim=2)}"
            )
        return dist_masks, gen_masks

    def get_tcpgen_distribution(
        self, query: torch.Tensor, dist_masks: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          query:
            The combined symbol and acoustic query, shape (B, T, s_range, attn_dim).
          dist_masks:
            The TCPGen distribution masks, shape (B, T, s_range, V + 1).

        Returns:
          The TCPGen hidden output hptr with shape (B, T, s_range, attn_dim)
          and the TCPGen distribution with shape (B, T, s_range, V + 1).
        """
        # Following the original paper, keys and values share the same embeddings.
        # (V + 1, embed_dim)
        kv = torch.cat([self.decoder.embedding.weight.data, self.oolemb.weight], dim=0)
        # (V + 1, attn_dim)
        kv = self.dropout_tcpgen(self.k_proj(kv))
        # (B, T, s_range, attn_dim) * (attn_dim, V + 1) -> (B, T, s_range, V + 1)
        distribution = torch.matmul(query, kv.permute(1, 0)) / math.sqrt(query.size(-1))
        distribution = distribution.masked_fill(
            dist_masks, torch.finfo(distribution.dtype).min
        )
        distribution = distribution.softmax(dim=-1)
        # (B, T, s_range, V) * (V, attn_dim) -> (B, T, s_range, attn_dim)
        # The OOL entry is excluded from the value matrix.
        hptr = torch.matmul(distribution[:, :, :, :-1], kv[:-1, :])
        hptr = self.dropout_tcpgen(hptr)
        if random.random() > 0.95:
            logging.info(
                f"distribution mean : {torch.mean(distribution, dim=3)}, "
                f"std: {torch.std(distribution, dim=3)}"
            )
            logging.info(
                f"distribution min : {torch.min(distribution, dim=3)}, "
                f"max: {torch.max(distribution, dim=3)}"
            )
        return hptr, distribution

    def prune_query_and_mask(
        self,
        query_sym: torch.Tensor,
        query_acoustic: torch.Tensor,
        dist_masks: torch.Tensor,
        gen_masks: torch.Tensor,
        ranges: torch.Tensor,
    ):
        """Prune the queries from symbols and acoustics with the ranges
        generated by `get_rnnt_prune_ranges` in the pruned RNN-T loss.

        Args:
          query_sym:
            The symbol query, with shape (B, S, attn_dim).
          query_acoustic:
            The acoustic query, with shape (B, T, attn_dim).
          dist_masks:
            The TCPGen distribution masks, with shape (B, S, V + 1).
          gen_masks:
            The TCPGen probability masks, with shape (B, S).
          ranges:
            A tensor containing the symbol indexes for each frame that we want
            to keep. Its shape is (B, T, s_range); see the docs of
            `get_rnnt_prune_ranges` in rnnt_loss.py for more details.

        Returns:
          The pruned query with shape (B, T, s_range, attn_dim), the pruned
          distribution masks with shape (B, T, s_range, V + 1) and the pruned
          generation masks with shape (B, T, s_range).
        """
        assert ranges.shape[0] == query_sym.shape[0], (ranges.shape, query_sym.shape)
        assert ranges.shape[0] == query_acoustic.shape[0], (
            ranges.shape,
            query_acoustic.shape,
        )
        assert query_acoustic.shape[1] == ranges.shape[1], (
            query_acoustic.shape,
            ranges.shape,
        )
        (B, T, s_range) = ranges.shape
        (B, S, attn_dim) = query_sym.shape
        assert query_acoustic.shape == (B, T, attn_dim), (
            query_acoustic.shape,
            (B, T, attn_dim),
        )
        assert dist_masks.shape == (B, S, self.vocab_size + 1), (
            dist_masks.shape,
            (B, S, self.vocab_size + 1),
        )
        assert gen_masks.shape == (B, S), (gen_masks.shape, (B, S))
        # (B, T, s_range, attn_dim)
        query_acoustic_pruned = query_acoustic.unsqueeze(2).expand(
            (B, T, s_range, attn_dim)
        )
        # (B, T, s_range, attn_dim)
        query_sym_pruned = torch.gather(
            query_sym,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, attn_dim)),
        ).reshape(B, T, s_range, attn_dim)
        # (B, T, s_range, V + 1)
        dist_masks_pruned = torch.gather(
            dist_masks,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand(
                (B, T * s_range, self.vocab_size + 1)
            ),
        ).reshape(B, T, s_range, self.vocab_size + 1)
        # (B, T, s_range)
        gen_masks_pruned = torch.gather(
            gen_masks, dim=1, index=ranges.reshape(B, T * s_range)
        ).reshape(B, T, s_range)
        return (
            query_sym_pruned + query_acoustic_pruned,
            dist_masks_pruned,
            gen_masks_pruned,
        )

    def forward(
        self,
        targets: torch.Tensor,
        encoder_embeddings: torch.Tensor,
        ranges: torch.Tensor,
        context_graph: ContextGraph,
    ):
        """
        Args:
          targets:
            The training targets in token ids (padded with blanks), shape (B, S).
          encoder_embeddings:
            The encoder outputs, shape (B, T, encoder_dim).
          ranges:
            The prune ranges from the pruned RNN-T loss, shape (B, T, s_range).
          context_graph:
            The context graph (the biasing-list trie) shared by this batch.

        Returns:
          The TCPGen hidden output with shape (B, T, s_range, attn_dim), the
          TCPGen distribution with shape (B, T, s_range, V + 1) and the pruned
          generation masks with shape (B, T, s_range).
        """
        query_sym = self.decoder.embedding(targets)
        query_sym = self.q_proj_sym(query_sym)  # (B, S, attn_dim)
        query_acoustic = self.q_proj_acoustic(encoder_embeddings)  # (B, T, attn_dim)
        # dist_masks : (B, S, V + 1)
        # gen_masks : (B, S)
        dist_masks, gen_masks = self.get_tcpgen_masks(
            targets=targets, context_graph=context_graph, vocab_size=self.vocab_size
        )
        # query : (B, T, s_range, attn_dim)
        # dist_masks : (B, T, s_range, V + 1)
        query, dist_masks, gen_masks = self.prune_query_and_mask(
            query_sym=query_sym,
            query_acoustic=query_acoustic,
            dist_masks=dist_masks,
            gen_masks=gen_masks,
            ranges=ranges,
        )
        if random.random() >= 0.95:
            logging.info(
                f"pruned gen_mask nonzero {gen_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"pruned dist_masks nonzero {dist_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(dist_masks), dim=3)}"
            )
        # hptr : (B, T, s_range, attn_dim)
        # tcpgen_dist : (B, T, s_range, V + 1)
        hptr, tcpgen_dist = self.get_tcpgen_distribution(query, dist_masks)
        return hptr, tcpgen_dist, gen_masks

    def generator_prob(
        self, hptr: torch.Tensor, h_joiner: torch.Tensor, gen_masks: torch.Tensor
    ) -> torch.Tensor:
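        """Compute the TCPGen generation probability (the pointer gate).

        Args:
          hptr:
            The TCPGen hidden output, shape (B, T, s_range, attn_dim).
          h_joiner:
            The joiner activations before the output projection, shape
            (B, T, s_range, joiner_dim).
          gen_masks:
            The TCPGen probability masks, shape (B, T, s_range); the
            probability is zeroed where the mask is True.

        Returns:
          The generation probability with shape (B, T, s_range, 1).
        """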
        # tcpgen_prob : (B, T, s_range, 1)
        tcpgen_prob = self.tcp_gate(torch.cat((h_joiner, hptr), dim=-1))
        tcpgen_prob = torch.sigmoid(tcpgen_prob)
        tcpgen_prob = tcpgen_prob.masked_fill(gen_masks.unsqueeze(-1), 0)
        if random.random() >= 0.95:
            logging.info(
                f"tcpgen_prob mean : {torch.mean(tcpgen_prob.squeeze(-1), dim=(1, 2))}, "
                f"std : {torch.std(tcpgen_prob.squeeze(-1), dim=(1, 2))}"
            )
            logging.info(
                f"tcpgen_prob min : {torch.min(tcpgen_prob.squeeze(-1), dim=1)}, "
                f"max : {torch.max(tcpgen_prob.squeeze(-1), dim=1)}"
            )
        return tcpgen_prob
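A minimal usage sketch of the masking logic (not taken from the commit; `DummyDecoder`, the dimensions and the token ids are made-up stand-ins for the recipe's real decoder and vocabulary):

import torch

from biasing import TCPGen
from icefall import ContextGraph


class DummyDecoder(torch.nn.Module):
    def __init__(self, vocab_size: int = 10, embed_dim: int = 8):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)


decoder = DummyDecoder()
tcpgen = TCPGen(
    encoder_dim=16, embed_dim=8, joiner_dim=8, decoder=decoder, attn_dim=8
)
# A biasing list with two phrases, given as token-id sequences.
graph = ContextGraph(0)
graph.build([[3, 4], [3, 5, 6]])
# One utterance: blank, a complete match of [3, 4], then an unmatched token.
targets = torch.tensor([[0, 3, 4, 7]])
dist_masks, gen_masks = tcpgen.get_tcpgen_masks(
    targets=targets, context_graph=graph, vocab_size=decoder.vocab_size
)
print(gen_masks)         # tensor([[False, False, False, True]])
print(dist_masks.shape)  # torch.Size([1, 4, 11]), i.e. (B, S, V + 1)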

joiner.py

@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 from scaling import ScaledLinear
+from typing import Optional


 class Joiner(nn.Module):
@@ -37,6 +38,7 @@ class Joiner(nn.Module):
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
+        tcpgen_hptr: Optional[torch.Tensor] = None,
         project_input: bool = True,
     ) -> torch.Tensor:
         """
@@ -62,6 +64,11 @@ class Joiner(nn.Module):
         else:
             logit = encoder_out + decoder_out

-        logit = self.output_linear(torch.tanh(logit))
-        return logit
+        if tcpgen_hptr is not None:
+            logit += tcpgen_hptr
+        activations = torch.tanh(logit)
+        logit = self.output_linear(activations)
+        return logit if tcpgen_hptr is None else (logit, activations)
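For orientation (my reading, not text from the commit): when `tcpgen_hptr` is given, the projected TCPGen hidden output is added to the joiner's pre-activation logits, and the tanh activations are returned alongside the logits so that `generator_prob` can consume them. The shapes below are illustrative:

# Plain transducer path: a single logits tensor.
logits = joiner(encoder_out=am, decoder_out=lm, project_input=False)

# TCPGen path: logits plus the pre-projection activations.
logits, activations = joiner(
    encoder_out=am,
    decoder_out=lm,
    tcpgen_hptr=tcpgen_hptr,  # (B, T, s_range, joiner_dim), from hptr_proj
    project_input=False,
)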

model.py

@@ -17,14 +17,17 @@
 # limitations under the License.

 from typing import Optional, Tuple

+import logging
 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
+from biasing import TCPGen

 from icefall.utils import add_sos, make_pad_mask
+from icefall import ContextGraph


 class AsrModel(nn.Module):
@@ -36,9 +39,13 @@ class AsrModel(nn.Module):
         joiner: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
+        joiner_dim: int = 512,
         vocab_size: int = 500,
+        tcpgen_attn_dim: int = 512,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        use_tcpgen_biasing: bool = False,
+        tcpgen_dropout: float = 0.15,
     ):
         """A joint CTC & Transducer ASR model.
@@ -111,6 +118,19 @@ class AsrModel(nn.Module):
             nn.LogSoftmax(dim=-1),
         )

+        self.use_tcpgen_biasing = use_tcpgen_biasing
+        if use_tcpgen_biasing:
+            assert use_transducer, "TCPGen biasing is only supported on transducer models."
+            self.tcp_gen = TCPGen(
+                encoder_dim=encoder_dim,
+                embed_dim=decoder_dim,
+                attn_dim=tcpgen_attn_dim,
+                joiner_dim=joiner_dim,
+                decoder=decoder,
+                tcpgen_dropout=tcpgen_dropout,
+            )
+            self.hptr_proj = torch.nn.Linear(tcpgen_attn_dim, joiner_dim)
+
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -180,6 +200,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        context_graph: Optional[ContextGraph] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -202,6 +223,7 @@ class AsrModel(nn.Module):
         """
         # Now for the decoder, i.e., the prediction network
         blank_id = self.decoder.blank_id
+        assert blank_id == 0, f"Assuming blank_id is 0, given {blank_id}"

         sos_y = add_sos(y, sos_id=blank_id)

         # sos_y_padded: [B, S + 1], start with SOS.
@@ -226,11 +248,6 @@ class AsrModel(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -252,19 +269,66 @@ class AsrModel(nn.Module):
             s_range=prune_range,
         )

-        # am_pruned : [B, T, prune_range, encoder_dim]
-        # lm_pruned : [B, T, prune_range, decoder_dim]
+        # am_pruned : [B, T, prune_range, joiner_dim]
+        # lm_pruned : [B, T, prune_range, joiner_dim]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
             am=self.joiner.encoder_proj(encoder_out),
             lm=self.joiner.decoder_proj(decoder_out),
             ranges=ranges,
         )

-        # logits : [B, T, prune_range, vocab_size]
-        # project_input=False since we applied the decoder's input projections
-        # prior to do_rnnt_pruning (this is an optimization for speed).
-        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+        fused_log_softmax = True
+        if self.use_tcpgen_biasing and context_graph is not None:
+            # logging.info("using tcpgen biasing")
+            fused_log_softmax = False
+            # hptr : (B, T, s_range, attn_dim)
+            # tcpgen_dist : (B, T, s_range, V + 1)
+            # gen_masks : (B, T, s_range)
+            hptr, tcpgen_dist, gen_masks = self.tcp_gen(
+                targets=sos_y_padded,
+                encoder_embeddings=encoder_out,
+                ranges=ranges,
+                context_graph=context_graph,
+            )
+            # (B, T, s_range, joiner_dim)
+            tcpgen_hptr = self.hptr_proj(hptr)
+            # logits : (B, T, s_range, V)
+            # activations : (B, T, s_range, joiner_dim)
+            logits, activations = self.joiner(
+                encoder_out=am_pruned,
+                decoder_out=lm_pruned,
+                tcpgen_hptr=tcpgen_hptr,
+                project_input=False,
+            )
+            # (B, T, s_range, 1)
+            tcpgen_prob = self.tcp_gen.generator_prob(
+                hptr=hptr, h_joiner=activations, gen_masks=gen_masks
+            )
+            # Assuming blank_id is 0.
+            p_mdl = torch.softmax(logits, dim=-1)
+            p_no_blank = 1.0 - p_mdl[:, :, :, 0:1]
+            # The blank entry of tcpgen_dist (tcpgen_dist[:, :, :, 0]) should be 0.
+            scaled_tcpgen_dist = tcpgen_dist[:, :, :, 1:] * p_no_blank
+            # OOL probability : tcpgen_dist[:, :, :, -1:]
+            scaled_tcpgen_prob = tcpgen_prob * (1 - tcpgen_dist[:, :, :, -1:])
+            interpolated_no_blank = scaled_tcpgen_dist[
+                :, :, :, :-1
+            ] * tcpgen_prob + p_mdl[:, :, :, 1:] * (1 - scaled_tcpgen_prob)
+            final_dist = torch.cat(
+                [p_mdl[:, :, :, 0:1], interpolated_no_blank], dim=-1
+            )
+            # These are actually log-probs, not raw logits.
+            logits = torch.log(final_dist + torch.finfo(final_dist.dtype).tiny)
+        else:
+            # logits : [B, T, prune_range, vocab_size]
+            # project_input=False since we applied the decoder's input projections
+            # prior to do_rnnt_pruning (this is an optimization for speed).
+            logits = self.joiner(am_pruned, lm_pruned, project_input=False)

         with torch.cuda.amp.autocast(enabled=False):
             pruned_loss = k2.rnnt_loss_pruned(
@@ -274,6 +338,7 @@ class AsrModel(nn.Module):
                 termination_symbol=blank_id,
                 boundary=boundary,
                 reduction="sum",
+                fused_log_softmax=fused_log_softmax,
             )

         return simple_loss, pruned_loss
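A reading aid for the interpolation above (notation mine, not from the commit): write $P_{mdl}$ for the transducer softmax over $\{\varnothing\} \cup V$, $P_{ptr}$ for the TCPGen distribution (its blank entry masked to zero, with an extra OOL entry), and $g$ for the sigmoid gate from `generator_prob`. Per $(t, u)$ position the code computes

$$P(\varnothing) = P_{mdl}(\varnothing), \qquad
P(y) = g\,\bigl(1 - P_{mdl}(\varnothing)\bigr) P_{ptr}(y)
     + \Bigl(1 - g\,\bigl(1 - P_{ptr}(\mathrm{OOL})\bigr)\Bigr) P_{mdl}(y)
\quad \text{for } y \neq \varnothing.$$

Since the non-blank mass of $P_{ptr}$ is exactly $1 - P_{ptr}(\mathrm{OOL})$, the result still sums to one, so taking `torch.log` yields valid log-probs, which is why `rnnt_loss_pruned` is called with `fused_log_softmax=False` on this path.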
@@ -286,6 +351,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        context_graph: Optional[ContextGraph] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -338,6 +404,7 @@ class AsrModel(nn.Module):
                 prune_range=prune_range,
                 am_scale=am_scale,
                 lm_scale=lm_scale,
+                context_graph=context_graph,
             )
         else:
             simple_loss = torch.empty(0)
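A hedged construction sketch matching the new arguments (values are the defaults above; `encoder_embed`, `encoder`, `decoder` and `joiner` come from the recipe's builders and are assumed here):

model = AsrModel(
    encoder_embed=encoder_embed,
    encoder=encoder,
    decoder=decoder,
    joiner=joiner,
    encoder_dim=384,
    decoder_dim=512,
    joiner_dim=512,
    vocab_size=500,
    tcpgen_attn_dim=512,
    use_transducer=True,
    use_ctc=False,
    use_tcpgen_biasing=True,  # enables the TCPGen branch in forward_transducer
    tcpgen_dropout=0.15,
)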

train.py

@@ -54,10 +54,11 @@ It supports training with:
 import argparse
 import copy
 import logging
+import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import k2
 import optim
@@ -81,7 +82,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2

-from icefall import diagnostics
+from icefall import ContextGraph, diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@ -259,6 +260,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="If True, use CTC head.", help="If True, use CTC head.",
) )
parser.add_argument(
"--use-tcpgen-biasing",
type=str2bool,
default=False,
help="If True, use tcpgen context biasing module",
)
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@@ -522,6 +530,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
+            "epoch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
@@ -530,6 +539,10 @@ def get_params() -> AttributeDict:
             "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 2000,
             "env_info": get_env_info(),
+            # parameters for tcpgen
+            "tcpgen_start_epoch": 10,
+            "num_distractors": 500,
+            "distractors_list": [],
         }
     )
@@ -624,9 +637,11 @@ def get_model(params: AttributeDict) -> nn.Module:
         joiner=joiner,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        use_tcpgen_biasing=params.use_tcpgen_biasing,
     )
     return model
@@ -684,6 +699,7 @@ def load_checkpoint_if_available(
         "best_train_epoch",
         "best_valid_epoch",
         "batch_idx_train",
+        "epoch_idx_train",
         "best_train_loss",
         "best_valid_loss",
     ]
@@ -747,6 +763,40 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


+def prepare_context_graph(
+    params: AttributeDict,
+    texts: List[str],
+    sp: spm.SentencePieceProcessor,
+) -> Optional[ContextGraph]:
+    # During the first epoch we only collect transcripts; they serve as
+    # distractors in later epochs.
+    if params.epoch_idx_train == params.start_epoch:
+        params.distractors_list += texts
+        return None
+    logging.info(f"distractors_list : {len(params.distractors_list)}")
+    if params.epoch_idx_train >= params.tcpgen_start_epoch:
+        contexts_list = []
+        selected_texts = []
+        # Keep each utterance of the batch with probability 0.5.
+        for text in texts:
+            if random.random() < 0.5:
+                selected_texts.append(text)
+        # Mix in randomly sampled distractor transcripts.
+        for _ in range(params.num_distractors):
+            index = random.randint(0, len(params.distractors_list) - 1)
+            selected_texts.append(params.distractors_list[index])
+        # From each selected text, take a random phrase of 1 to 3 words.
+        for st in selected_texts:
+            word_list = st.split()
+            start = random.randint(0, len(word_list) - 1)
+            length = random.randint(1, 3)
+            contexts_list.append(" ".join(word_list[start : start + length]))
+        contexts_tokens = sp.encode(contexts_list, out_type=int)
+        context_graph = ContextGraph(0)
+        context_graph.build(contexts_tokens)
+        return context_graph
+    return None
+
+
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -788,6 +838,11 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)

+    context_graph = None
+    if params.use_tcpgen_biasing:
+        context_graph = prepare_context_graph(params=params, texts=texts, sp=sp)
+
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss = model(
             x=feature,
@@ -796,6 +851,7 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            context_graph=context_graph,
         )

     loss = 0.0
@@ -935,6 +991,8 @@ def train_one_epoch(
             rank=0,
         )

+    logging.info(f"epoch_idx_train : {params.epoch_idx_train}")
+    params.epoch_idx_train += 1
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
@@ -963,9 +1021,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except Exception as e:
-            logging.info(
-                f"Caught exception: {e}."
-            )
+            logging.info(f"Caught exception: {e}.")
             save_bad_model()
             display_and_save_batch(batch, params=params, sp=sp)
             raise
@@ -1173,16 +1229,16 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)

     if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
+        # train_cuts = librispeech.train_all_shuf_cuts()

         # previously we used the following code to load all training cuts,
         # strictly speaking, shuffled training cuts should be used instead,
         # but we leave the code here to demonstrate that there is an option
         # like this to combine multiple cutsets
-        # train_cuts = librispeech.train_clean_100_cuts()
-        # train_cuts += librispeech.train_clean_360_cuts()
-        # train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_clean_100_cuts()
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
     else:
         train_cuts = librispeech.train_clean_100_cuts()
@@ -1201,26 +1257,6 @@ def run(rank, world_size, args):
             # )
             return False

-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        T = ((c.num_frames - 7) // 2 + 1) // 2
-        tokens = sp.encode(c.supervisions[0].text, out_type=str)
-
-        if T < len(tokens):
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. "
-                f"Number of frames (before subsampling): {c.num_frames}. "
-                f"Number of frames (after subsampling): {T}. "
-                f"Text: {c.supervisions[0].text}. "
-                f"Tokens: {tokens}. "
-                f"Number of tokens: {len(tokens)}"
-            )
-            return False
-
         return True

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
@@ -1241,13 +1277,14 @@ def run(rank, world_size, args):
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+        pass
+        # scan_pessimistic_batches_for_oom(
+        #     model=model,
+        #     train_dl=train_dl,
+        #     optimizer=optimizer,
+        #     sp=sp,
+        #     params=params,
+        # )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1393,4 +1430,5 @@
 torch.set_num_threads(1)
 torch.set_num_interop_threads(1)

 if __name__ == "__main__":
+    torch.set_printoptions(profile="full")
     main()
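A hedged sketch of the schedule implemented by `prepare_context_graph`, `epoch_idx_train` and the `context_graph` plumbing above; `batches` and the outer loop are illustrative, not recipe code:

# Illustration only: how the per-batch biasing list evolves with the
# defaults above, assuming --start-epoch=1 and tcpgen_start_epoch=10.
for epoch in range(1, 31):
    params.epoch_idx_train = epoch  # train_one_epoch does this increment
    for texts in batches:  # hypothetical lists of transcript strings
        context_graph = prepare_context_graph(params=params, texts=texts, sp=sp)
        # epoch == 1     -> None; transcripts stockpiled as distractors
        # 1 < epoch < 10 -> None; training proceeds without biasing
        # epoch >= 10    -> fresh graph: ~50% of the batch's own phrases
        #                   plus num_distractors sampled distractor phrases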

icefall/context_graph.py

@@ -162,6 +162,7 @@ class ContextGraph:
         phrases: Optional[List[str]] = None,
         scores: Optional[List[float]] = None,
         ac_thresholds: Optional[List[float]] = None,
+        simple_trie: bool = False,
     ):
         """Build the ContextGraph from a list of token lists.

         It first builds a trie from the given token lists, then fills the fail arc
@@ -189,6 +190,11 @@ class ContextGraph:
             0 means using the default value (i.e. self.ac_threshold). It is
             used only when this graph is applied in a keyword-spotting system.
             The length of `ac_thresholds` MUST be equal to the length of `token_ids`.
+          simple_trie:
+            True to build only the trie (i.e. no fail or output arcs), as
+            needed by TCPGen biasing training.
+            False to build an Aho-Corasick automaton, for hotword / keyword
+            searching.

         Note: The phrases share states; the score of a shared state is the
         MAXIMUM value among all the phrases sharing this state.
@@ -211,7 +217,6 @@ class ContextGraph:
             context_score = self.context_score if score == 0.0 else score
             threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold
             for i, token in enumerate(tokens):
-                node_next = {}
                 if token not in node.next:
                     self.num_nodes += 1
                     is_end = i == len(tokens) - 1
@@ -240,7 +245,9 @@ class ContextGraph:
                 node.next[token].phrase = phrase
                 node.next[token].ac_threshold = threshold
                 node = node.next[token]
-        self._fill_fail_output()
+
+        if not simple_trie:
+            self._fill_fail_output()

     def forward_one_step(
         self, state: ContextState, token: int, strict_mode: bool = True
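A hedged usage sketch of the new flag (the token ids are made up; the build call matches the signature above):

from icefall import ContextGraph

# Build a plain trie for TCPGen biasing: no fail/output arcs are added,
# so only node.next / node.is_end are populated.
graph = ContextGraph(context_score=0)
graph.build([[5, 7], [5, 8, 9]], simple_trie=True)
assert 5 in graph.root.next and 7 in graph.root.next[5].next

# Default behaviour is unchanged: a full Aho-Corasick automaton for
# hotword / keyword search.
kws_graph = ContextGraph(context_score=2.0)
kws_graph.build([[5, 7], [5, 8, 9]])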