Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-26 10:16:14 +00:00.
Merge cfb0cab3ebf719b393960ea3565361353f74eb28 into cbcac23d2617ccfdc8f1ecc14a00ba96413c3bf9
This commit is contained in commit bacbffac11.

egs/librispeech/ASR/zipformer/biasing.py (new file, 284 lines)
@@ -0,0 +1,284 @@
# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import random
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from icefall import ContextGraph

class TCPGen(nn.Module):
    def __init__(
        self,
        encoder_dim: int,
        embed_dim: int,
        joiner_dim: int,
        decoder: nn.Module,
        attn_dim: int = 512,
        tcpgen_dropout: float = 0.1,
    ):
        super().__init__()
        # embedding for the OOL (out of biasing list) token
        self.oolemb = torch.nn.Embedding(1, embed_dim)
        # project symbol embeddings
        self.q_proj_sym = torch.nn.Linear(embed_dim, attn_dim)
        # project encoder embeddings
        self.q_proj_acoustic = torch.nn.Linear(encoder_dim, attn_dim)
        # project symbol embeddings (vocabulary + OOL)
        self.k_proj = torch.nn.Linear(embed_dim, attn_dim)
        # generate the TCPGen probability
        self.tcp_gate = torch.nn.Linear(attn_dim + joiner_dim, 1)
        self.dropout_tcpgen = torch.nn.Dropout(tcpgen_dropout)
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size
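Editorial aside (not part of the commit): `TCPGen` only needs a decoder exposing an `embedding` module and a `vocab_size` attribute, which the recipe's `Decoder` provides. A minimal sketch with a hypothetical stand-in decoder (`_StubDecoder` and all dimensions below are assumptions for illustration):

import torch.nn as nn

class _StubDecoder(nn.Module):
    """Hypothetical stand-in for the recipe's Decoder, for illustration."""

    def __init__(self, vocab_size: int = 500, embed_dim: int = 512):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)

tcpgen = TCPGen(
    encoder_dim=384,
    embed_dim=512,
    joiner_dim=512,
    decoder=_StubDecoder(),
)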
    def get_tcpgen_masks(
        self,
        targets: torch.Tensor,
        context_graph: ContextGraph,
        vocab_size: int,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, seq_len = targets.shape
        dist_masks = torch.ones((batch_size, seq_len, vocab_size + 1))
        gen_masks = []
        yseqs = targets.tolist()
        for i, yseq in enumerate(yseqs):
            node = context_graph.root
            gen_mask = []
            for j, y in enumerate(yseq):
                not_matched = False
                if y == blank_id:
                    node = context_graph.root
                    gen_mask.append(0)
                elif y in node.next:
                    gen_mask.append(0)
                    node = node.next[y]
                    if node.is_end:
                        node = context_graph.root
                else:
                    gen_mask.append(1)
                    node = context_graph.root
                    not_matched = True
                # unmask_index = (
                #     [vocab_size]
                #     if node.token == -1
                #     else list(node.next.keys()) + [vocab_size]
                # )
                # logging.info(f"token : {node.token}, keys : {node.next.keys()}")
                # dist_masks[i, j, unmask_index] = 0
                if not not_matched:
                    dist_masks[i, j, list(node.next.keys())] = 0

            gen_masks.append(gen_mask + [1] * (seq_len - len(gen_mask)))
        gen_masks = torch.Tensor(gen_masks).to(targets.device).bool()
        dist_masks = dist_masks.to(targets.device).bool()
        if random.random() >= 0.95:
            logging.info(
                f"gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=2)}"
            )
        return dist_masks, gen_masks
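To make the walk above concrete, here is an editorial sketch (not part of the commit) that reimplements the per-token trie walk on a plain dict-based trie so it runs standalone; the `next`/`is_end` fields mirror the `ContextGraph` node attributes the method uses:

def toy_gen_mask(yseq, root, blank_id=0):
    node, gen_mask = root, []
    for y in yseq:
        if y == blank_id:  # blank resets the walk to the trie root
            node = root
            gen_mask.append(0)
        elif y in node["next"]:  # still inside a biasing phrase
            node = node["next"][y]
            gen_mask.append(0)
            if node["is_end"]:
                node = root
        else:  # fell off the trie: suppress TCPGen at this position
            node = root
            gen_mask.append(1)
    return gen_mask

# one biasing phrase [5, 7]; targets [blank, 5, 7, 3] -> [0, 0, 0, 1]
trie = {
    "next": {5: {"next": {7: {"next": {}, "is_end": True}}, "is_end": False}},
    "is_end": False,
}
assert toy_gen_mask([0, 5, 7, 3], trie) == [0, 0, 0, 1]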
    def get_tcpgen_distribution(
        self, query: torch.Tensor, dist_masks: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          query:
            shape : (B, T, s_range, attn_dim)
          dist_masks:
            shape : (B, T, s_range, V + 1)
        """
        # From the original paper, k and v share the same embeddings
        # (V + 1, embed_dim)
        kv = torch.cat([self.decoder.embedding.weight.data, self.oolemb.weight], dim=0)
        # (V + 1, attn_dim)
        kv = self.dropout_tcpgen(self.k_proj(kv))
        # (B, T, s_range, attn_dim) * (attn_dim, V + 1) -> (B, T, s_range, V + 1)
        distribution = torch.matmul(query, kv.permute(1, 0)) / math.sqrt(query.size(-1))
        distribution = distribution.masked_fill(
            dist_masks, torch.finfo(distribution.dtype).min
        )
        distribution = distribution.softmax(dim=-1)
        # (B, T, s_range, V) * (V, attn_dim) -> (B, T, s_range, attn_dim)
        # logging.info(f"distribution shape : {distribution.shape}")
        hptr = torch.matmul(distribution[:, :, :, :-1], kv[:-1, :])
        hptr = self.dropout_tcpgen(hptr)
        if random.random() > 0.95:
            logging.info(
                f"distribution mean : {torch.mean(distribution, dim=3)}, std: {torch.std(distribution, dim=3)}"
            )
            logging.info(
                f"distribution min : {torch.min(distribution, dim=3)}, max: {torch.max(distribution, dim=3)}"
            )
        return hptr, distribution
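The distribution above is masked scaled dot-product attention over the vocabulary-plus-OOL embedding table. An editorial numeric sketch (not from the commit) showing that slots filled with the dtype minimum receive numerically zero probability after the softmax:

import math
import torch

attn_dim = 4
query = torch.randn(1, 1, 1, attn_dim)  # (B, T, s_range, attn_dim)
kv = torch.randn(6, attn_dim)           # V + 1 = 6 rows (vocab of 5 plus the OOL slot)

scores = torch.matmul(query, kv.T) / math.sqrt(attn_dim)
mask = torch.tensor([True, False, True, True, False, False])  # True = masked out
scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
dist = scores.softmax(dim=-1)

assert torch.allclose(dist[..., mask].sum(), torch.tensor(0.0), atol=1e-6)
assert torch.allclose(dist.sum(-1), torch.ones(1, 1, 1))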
    def prune_query_and_mask(
        self,
        query_sym: torch.Tensor,
        query_acoustic: torch.Tensor,
        dist_masks: torch.Tensor,
        gen_masks: torch.Tensor,
        ranges: torch.Tensor,
    ):
        """Prune the symbol and acoustic queries with the ranges
        generated by `get_rnnt_prune_ranges` in the pruned rnnt loss.

        Args:
          query_sym:
            The symbol query, with shape (B, S, attn_dim).
          query_acoustic:
            The acoustic query, with shape (B, T, attn_dim).
          dist_masks:
            The TCPGen distribution masks, with shape (B, S, V + 1).
          gen_masks:
            The TCPGen probability masks, with shape (B, S).
          ranges:
            A tensor containing the symbol indexes for each frame that we want to
            keep. Its shape is (B, T, s_range); see the docs of
            `get_rnnt_prune_ranges` in rnnt_loss.py for more details of this tensor.

        Returns:
          Return the pruned query with shape (B, T, s_range, attn_dim), the pruned
          distribution masks with shape (B, T, s_range, V + 1) and the pruned
          probability masks with shape (B, T, s_range).
        """
        assert ranges.shape[0] == query_sym.shape[0], (ranges.shape, query_sym.shape)
        assert ranges.shape[0] == query_acoustic.shape[0], (
            ranges.shape,
            query_acoustic.shape,
        )
        assert query_acoustic.shape[1] == ranges.shape[1], (
            query_acoustic.shape,
            ranges.shape,
        )
        (B, T, s_range) = ranges.shape
        (B, S, attn_dim) = query_sym.shape
        assert query_acoustic.shape == (B, T, attn_dim), (
            query_acoustic.shape,
            (B, T, attn_dim),
        )
        assert dist_masks.shape == (B, S, self.vocab_size + 1), (
            dist_masks.shape,
            (B, S, self.vocab_size + 1),
        )
        assert gen_masks.shape == (B, S), (gen_masks.shape, (B, S))
        # (B, T, s_range, attn_dim)
        query_acoustic_pruned = query_acoustic.unsqueeze(2).expand(
            (B, T, s_range, attn_dim)
        )
        # logging.info(f"query_sym : {query_sym.shape}")
        # logging.info(f"ranges : {ranges}")
        # (B, T, s_range, attn_dim)
        query_sym_pruned = torch.gather(
            query_sym,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, attn_dim)),
        ).reshape(B, T, s_range, attn_dim)
        # (B, T, s_range, V + 1)
        dist_masks_pruned = torch.gather(
            dist_masks,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand(
                (B, T * s_range, self.vocab_size + 1)
            ),
        ).reshape(B, T, s_range, self.vocab_size + 1)
        # (B, T, s_range)
        gen_masks_pruned = torch.gather(
            gen_masks, dim=1, index=ranges.reshape(B, T * s_range)
        ).reshape(B, T, s_range)
        return (
            query_sym_pruned + query_acoustic_pruned,
            dist_masks_pruned,
            gen_masks_pruned,
        )
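An editorial sketch (not from the commit) of the `torch.gather` pattern above: each frame t keeps the s_range consecutive symbol positions named by `ranges`, the same pruning `k2.do_rnnt_pruning` applies to the am/lm; tiny shapes are used for readability:

import torch

B, T, s_range, S, attn_dim = 1, 3, 2, 4, 1
query_sym = torch.arange(S, dtype=torch.float32).reshape(B, S, attn_dim)
ranges = torch.tensor([[[0, 1], [1, 2], [2, 3]]])  # (B, T, s_range)

pruned = torch.gather(
    query_sym,
    dim=1,
    index=ranges.reshape(B, T * s_range, 1).expand(B, T * s_range, attn_dim),
).reshape(B, T, s_range, attn_dim)

# frame 0 sees symbols 0,1; frame 1 sees 1,2; frame 2 sees 2,3
assert pruned.squeeze(-1).tolist() == [[[0.0, 1.0], [1.0, 2.0], [2.0, 3.0]]]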
    def forward(
        self,
        targets: torch.Tensor,
        encoder_embeddings: torch.Tensor,
        ranges: torch.Tensor,
        context_graph: ContextGraph,
    ):
        """
        Args:
          targets:
            The training targets in token ids (padded with blanks). shape : (B, S)
          encoder_embeddings:
            The encoder outputs. shape: (B, T, encoder_dim)
          ranges:
            The prune ranges from pruned rnnt. shape: (B, T, s_range)
          context_graph:
            The context graph (biasing trie) shared by the utterances in this batch.

        Returns:
          Return the tcpgen embedding with shape (B, T, s_range, attn_dim), the
          tcpgen distribution with shape (B, T, s_range, V + 1) and the pruned
          gen_masks with shape (B, T, s_range).
        """
        query_sym = self.decoder.embedding(targets)

        query_sym = self.q_proj_sym(query_sym)  # (B, S, attn_dim)
        query_acoustic = self.q_proj_acoustic(encoder_embeddings)  # (B, T, attn_dim)

        # dist_masks : (B, S, V + 1)
        # gen_masks : (B, S)
        dist_masks, gen_masks = self.get_tcpgen_masks(
            targets=targets, context_graph=context_graph, vocab_size=self.vocab_size
        )
        # query : (B, T, s_range, attn_dim)
        # dist_masks : (B, T, s_range, V + 1)
        query, dist_masks, gen_masks = self.prune_query_and_mask(
            query_sym=query_sym,
            query_acoustic=query_acoustic,
            dist_masks=dist_masks,
            gen_masks=gen_masks,
            ranges=ranges,
        )

        if random.random() >= 0.95:
            logging.info(
                f"pruned gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"pruned dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=3)}"
            )

        # hptr : (B, T, s_range, attn_dim)
        # tcpgen_dist : (B, T, s_range, V + 1)
        hptr, tcpgen_dist = self.get_tcpgen_distribution(query, dist_masks)
        return hptr, tcpgen_dist, gen_masks
    def generator_prob(
        self, hptr: torch.Tensor, h_joiner: torch.Tensor, gen_masks: torch.Tensor
    ) -> torch.Tensor:
        # tcpgen_prob : (B, T, s_range, 1)
        tcpgen_prob = self.tcp_gate(torch.cat((h_joiner, hptr), dim=-1))
        tcpgen_prob = torch.sigmoid(tcpgen_prob)
        tcpgen_prob = tcpgen_prob.masked_fill(gen_masks.unsqueeze(-1), 0)
        if random.random() >= 0.95:
            logging.info(
                f"tcpgen_prob mean : {torch.mean(tcpgen_prob.squeeze(-1), dim=(1,2))}, std : {torch.std(tcpgen_prob.squeeze(-1), dim=(1, 2))}"
            )
            logging.info(
                f"tcpgen_prob min : {torch.min(tcpgen_prob.squeeze(-1), dim=1)}, max : {torch.max(tcpgen_prob.squeeze(-1), dim=1)}"
            )
        return tcpgen_prob
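An editorial sketch (not from the commit) of the gating computation in `generator_prob`: a sigmoid over the concatenated joiner activations and TCPGen summary vector, zeroed wherever `gen_masks` marks the prefix as having fallen outside the biasing trie (all dimensions below are assumptions):

import torch
import torch.nn as nn

joiner_dim, attn_dim = 3, 2
gate = nn.Linear(attn_dim + joiner_dim, 1)
h_joiner = torch.randn(1, 4, 2, joiner_dim)      # (B, T, s_range, joiner_dim)
hptr = torch.randn(1, 4, 2, attn_dim)            # (B, T, s_range, attn_dim)
gen_masks = torch.tensor([[[False, True]] * 4])  # True = suppress TCPGen

p_gen = torch.sigmoid(gate(torch.cat((h_joiner, hptr), dim=-1)))
p_gen = p_gen.masked_fill(gen_masks.unsqueeze(-1), 0)
assert (p_gen[..., 1, :] == 0).all()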
egs/librispeech/ASR/zipformer/joiner.py

@@ -17,6 +17,7 @@
 import torch
 import torch.nn as nn
 from scaling import ScaledLinear
+from typing import Optional


 class Joiner(nn.Module):
@@ -37,6 +38,7 @@ class Joiner(nn.Module):
         self,
         encoder_out: torch.Tensor,
         decoder_out: torch.Tensor,
+        tcpgen_hptr: Optional[torch.Tensor] = None,
         project_input: bool = True,
     ) -> torch.Tensor:
         """
@@ -62,6 +64,11 @@ class Joiner(nn.Module):
         else:
             logit = encoder_out + decoder_out

-        logit = self.output_linear(torch.tanh(logit))
-        return logit
+        if tcpgen_hptr is not None:
+            logit += tcpgen_hptr
+
+        activations = torch.tanh(logit)
+        logit = self.output_linear(activations)
+
+        return logit if tcpgen_hptr is None else (logit, activations)
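Editorial note (not part of the commit text itself): when `tcpgen_hptr` is supplied, the joiner also returns the pre-`output_linear` activations, because `TCPGen.generator_prob` gates on exactly those activations (`h_joiner`); see the call site in model.py below.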
egs/librispeech/ASR/zipformer/model.py

@@ -17,14 +17,17 @@
 # limitations under the License.

 from typing import Optional, Tuple
+import logging

 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
+from biasing import TCPGen

 from icefall.utils import add_sos, make_pad_mask
+from icefall import ContextGraph


 class AsrModel(nn.Module):
@@ -36,9 +39,13 @@ class AsrModel(nn.Module):
         joiner: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
+        joiner_dim: int = 512,
         vocab_size: int = 500,
+        tcpgen_attn_dim: int = 512,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        use_tcpgen_biasing: bool = False,
+        tcpgen_dropout: float = 0.15,
     ):
         """A joint CTC & Transducer ASR model.

@@ -111,6 +118,19 @@ class AsrModel(nn.Module):
             nn.LogSoftmax(dim=-1),
         )

+        self.use_tcpgen_biasing = use_tcpgen_biasing
+        if use_tcpgen_biasing:
+            assert use_transducer, "TCPGen biasing is only supported on transducer models."
+            self.tcp_gen = TCPGen(
+                encoder_dim=encoder_dim,
+                embed_dim=decoder_dim,
+                attn_dim=tcpgen_attn_dim,
+                joiner_dim=joiner_dim,
+                decoder=decoder,
+                tcpgen_dropout=tcpgen_dropout,
+            )
+            self.hptr_proj = torch.nn.Linear(tcpgen_attn_dim, joiner_dim)
+
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -180,6 +200,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        context_graph: Optional[ContextGraph] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -202,6 +223,7 @@ class AsrModel(nn.Module):
         """
         # Now for the decoder, i.e., the prediction network
         blank_id = self.decoder.blank_id
+        assert blank_id == 0, f"Assuming blank_id is 0, given {blank_id}"
         sos_y = add_sos(y, sos_id=blank_id)

         # sos_y_padded: [B, S + 1], start with SOS.
@@ -226,11 +248,6 @@ class AsrModel(nn.Module):
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -252,14 +269,61 @@ class AsrModel(nn.Module):
             s_range=prune_range,
         )

-        # am_pruned : [B, T, prune_range, encoder_dim]
-        # lm_pruned : [B, T, prune_range, decoder_dim]
+        # am_pruned : [B, T, prune_range, joiner_dim]
+        # lm_pruned : [B, T, prune_range, joiner_dim]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
             am=self.joiner.encoder_proj(encoder_out),
             lm=self.joiner.decoder_proj(decoder_out),
             ranges=ranges,
         )

+        fused_log_softmax = True
+        if self.use_tcpgen_biasing and context_graph is not None:
+            # logging.info("using tcpgen biasing")
+            fused_log_softmax = False
+            # hptr : (B, T, s_range, attn_dim)
+            # tcpgen_dist : (B, T, s_range, V + 1)
+            # gen_masks : (B, T, s_range)
+            hptr, tcpgen_dist, gen_masks = self.tcp_gen(
+                targets=sos_y_padded,
+                encoder_embeddings=encoder_out,
+                ranges=ranges,
+                context_graph=context_graph,
+            )
+            # (B, T, s_range, joiner_dim)
+            tcpgen_hptr = self.hptr_proj(hptr)
+            # logits : (B, T, s_range, V)
+            # activations : (B, T, s_range, joiner_dim)
+            logits, activations = self.joiner(
+                encoder_out=am_pruned,
+                decoder_out=lm_pruned,
+                tcpgen_hptr=tcpgen_hptr,
+                project_input=False,
+            )
+            # (B, T, s_range, 1)
+            tcpgen_prob = self.tcp_gen.generator_prob(
+                hptr=hptr, h_joiner=activations, gen_masks=gen_masks
+            )
+
+            # Assuming blank_id is 0
+            p_mdl = torch.softmax(logits, dim=-1)
+            p_no_blank = 1.0 - p_mdl[:, :, :, 0:1]
+
+            # the blank slot of the tcpgen dist (tcpgen_dist[:, :, :, 0]) should be 0
+            scaled_tcpgen_dist = tcpgen_dist[:, :, :, 1:] * p_no_blank
+
+            # ool probability : tcpgen_dist[:, :, :, -1:]
+            scaled_tcpgen_prob = tcpgen_prob * (1 - tcpgen_dist[:, :, :, -1:])
+
+            interpolated_no_blank = scaled_tcpgen_dist[
+                :, :, :, :-1
+            ] * tcpgen_prob + p_mdl[:, :, :, 1:] * (1 - scaled_tcpgen_prob)
+
+            final_dist = torch.cat([p_mdl[:, :, :, 0:1], interpolated_no_blank], dim=-1)
+
+            # actually log-probs from here on
+            logits = torch.log(final_dist + torch.finfo(final_dist.dtype).tiny)
+        else:
             # logits : [B, T, prune_range, vocab_size]

             # project_input=False since we applied the decoder's input projections
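The interpolation block above follows the TCPGen shortcut-distribution formula: the model's blank probability is kept as-is, the TCPGen distribution is rescaled to the model's non-blank mass, and the gate is discounted by the OOL probability. An editorial numeric check (not from the commit) that the combined distribution still sums to one for any gate value:

import torch

V = 5                                    # vocab incl. blank; one extra OOL slot
p_mdl = torch.softmax(torch.randn(V), dim=-1)
tcp = torch.softmax(torch.randn(V + 1), dim=-1)
tcp[0] = 0.0
tcp = tcp / tcp.sum()                    # blank slot forced to 0
gate = torch.sigmoid(torch.randn(1))     # tcpgen_prob

p_no_blank = 1.0 - p_mdl[0:1]
scaled_tcp = tcp[1:] * p_no_blank        # rescale to the model's non-blank mass
scaled_gate = gate * (1 - tcp[-1:])      # discount the gate by the OOL probability
mix = scaled_tcp[:-1] * gate + p_mdl[1:] * (1 - scaled_gate)
final = torch.cat([p_mdl[0:1], mix])

assert torch.allclose(final.sum(), torch.tensor(1.0), atol=1e-5)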
@@ -274,6 +338,7 @@ class AsrModel(nn.Module):
                 termination_symbol=blank_id,
                 boundary=boundary,
                 reduction="sum",
+                fused_log_softmax=fused_log_softmax,
             )

         return simple_loss, pruned_loss
@@ -286,6 +351,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        context_graph: Optional[ContextGraph] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:

@@ -338,6 +404,7 @@ class AsrModel(nn.Module):
                 prune_range=prune_range,
                 am_scale=am_scale,
                 lm_scale=lm_scale,
+                context_graph=context_graph,
             )
         else:
             simple_loss = torch.empty(0)
egs/librispeech/ASR/zipformer/train.py

@@ -54,10 +54,11 @@ It supports training with:
 import argparse
 import copy
 import logging
+import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import k2
 import optim
@@ -81,7 +82,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2

-from icefall import diagnostics
+from icefall import ContextGraph, diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -259,6 +260,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use CTC head.",
     )

+    parser.add_argument(
+        "--use-tcpgen-biasing",
+        type=str2bool,
+        default=False,
+        help="If True, use the TCPGen context biasing module.",
+    )
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -522,6 +530,7 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
+            "epoch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
@@ -530,6 +539,10 @@ def get_params() -> AttributeDict:
             "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 2000,
             "env_info": get_env_info(),
+            # parameters for tcpgen
+            "tcpgen_start_epoch": 10,
+            "num_distractors": 500,
+            "distractors_list": [],
         }
     )

@@ -624,9 +637,11 @@ def get_model(params: AttributeDict) -> nn.Module:
         joiner=joiner,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        use_tcpgen_biasing=params.use_tcpgen_biasing,
     )
     return model

@@ -684,6 +699,7 @@ def load_checkpoint_if_available(
         "best_train_epoch",
         "best_valid_epoch",
         "batch_idx_train",
+        "epoch_idx_train",
         "best_train_loss",
         "best_valid_loss",
     ]
@@ -747,6 +763,40 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


+def prepare_context_graph(
+    params: AttributeDict,
+    texts: List[str],
+    sp: spm.SentencePieceProcessor,
+) -> ContextGraph:
+    if params.epoch_idx_train == params.start_epoch:
+        params.distractors_list += texts
+        return None
+
+    logging.info(f"distractors_list : {len(params.distractors_list)}")
+    if params.epoch_idx_train >= params.tcpgen_start_epoch:
+        # logging.info("prepare context graph")
+        contexts_list = []
+        selected_texts = []
+        for i, text in enumerate(texts):
+            if random.random() >= 0.5:
+                continue
+            else:
+                selected_texts.append(text)
+        for i in range(params.num_distractors):
+            index = random.randint(0, len(params.distractors_list) - 1)
+            selected_texts.append(params.distractors_list[index])
+        for st in selected_texts:
+            word_list = st.split()
+            start = random.randint(0, len(word_list) - 1)
+            length = random.randint(1, 3)
+            contexts_list.append(" ".join(word_list[start : start + length]))
+        contexts_tokens = sp.encode(contexts_list, out_type=int)
+        # logging.info(f"contexts_list : {contexts_list}, tokens : {contexts_tokens}")
+        context_graph = ContextGraph(0)
+        context_graph.build(contexts_tokens)
+        return context_graph
+
+
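An editorial sketch (not from the commit) of the phrase sampling above: roughly half of the batch's own utterances, plus `num_distractors` utterances remembered from the first epoch, each contribute one random 1-3 word span as a biasing phrase:

import random

def toy_phrases(texts, num_distractors, distractors_list):
    selected = [t for t in texts if random.random() < 0.5]
    selected += random.choices(distractors_list, k=num_distractors)  # with replacement
    phrases = []
    for st in selected:
        words = st.split()
        start = random.randint(0, len(words) - 1)
        phrases.append(" ".join(words[start : start + random.randint(1, 3)]))
    return phrases

print(toy_phrases(["hello world"], 2, ["foo bar baz", "quick brown fox"]))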
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -788,6 +838,11 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)

+    context_graph = None
+    if params.use_tcpgen_biasing:
+        context_graph = prepare_context_graph(params=params, texts=texts, sp=sp)
+        # assert context_graph is not None
+
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss = model(
             x=feature,

@@ -796,6 +851,7 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            context_graph=context_graph,
         )

     loss = 0.0
@@ -935,6 +991,8 @@ def train_one_epoch(
             rank=0,
         )

+    logging.info(f"epoch_idx_train : {params.epoch_idx_train}")
+    params.epoch_idx_train += 1
     for batch_idx, batch in enumerate(train_dl):
         if batch_idx % 10 == 0:
             set_batch_count(model, get_adjusted_batch_count(params))
@@ -963,9 +1021,7 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
         except Exception as e:
-            logging.info(
-                f"Caught exception: {e}."
-            )
+            logging.info(f"Caught exception: {e}.")
             save_bad_model()
             display_and_save_batch(batch, params=params, sp=sp)
             raise
@@ -1173,16 +1229,16 @@ def run(rank, world_size, args):
     librispeech = LibriSpeechAsrDataModule(args)

     if params.full_libri:
-        train_cuts = librispeech.train_all_shuf_cuts()
+        # train_cuts = librispeech.train_all_shuf_cuts()

         # previously we used the following code to load all training cuts,
         # strictly speaking, shuffled training cuts should be used instead,
         # but we leave the code here to demonstrate that there is an option
         # like this to combine multiple cutsets
-        # train_cuts = librispeech.train_clean_100_cuts()
-        # train_cuts += librispeech.train_clean_360_cuts()
-        # train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_clean_100_cuts()
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
     else:
         train_cuts = librispeech.train_clean_100_cuts()

|
|||||||
# )
|
# )
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# In pruned RNN-T, we require that T >= S
|
|
||||||
# where T is the number of feature frames after subsampling
|
|
||||||
# and S is the number of tokens in the utterance
|
|
||||||
|
|
||||||
# In ./zipformer.py, the conv module uses the following expression
|
|
||||||
# for subsampling
|
|
||||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
|
||||||
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
|
||||||
|
|
||||||
if T < len(tokens):
|
|
||||||
logging.warning(
|
|
||||||
f"Exclude cut with ID {c.id} from training. "
|
|
||||||
f"Number of frames (before subsampling): {c.num_frames}. "
|
|
||||||
f"Number of frames (after subsampling): {T}. "
|
|
||||||
f"Text: {c.supervisions[0].text}. "
|
|
||||||
f"Tokens: {tokens}. "
|
|
||||||
f"Number of tokens: {len(tokens)}"
|
|
||||||
)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
@@ -1241,13 +1277,14 @@ def run(rank, world_size, args):
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+        pass
+        # scan_pessimistic_batches_for_oom(
+        #     model=model,
+        #     train_dl=train_dl,
+        #     optimizer=optimizer,
+        #     sp=sp,
+        #     params=params,
+        # )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1393,4 +1430,5 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)

 if __name__ == "__main__":
+    torch.set_printoptions(profile="full")
     main()
icefall/context_graph.py

@@ -162,6 +162,7 @@ class ContextGraph:
         phrases: Optional[List[str]] = None,
         scores: Optional[List[float]] = None,
         ac_thresholds: Optional[List[float]] = None,
+        simple_trie: bool = False,
     ):
         """Build the ContextGraph from a list of token lists.
         It first builds a trie from the given token lists, then fills in the fail arc
@@ -189,6 +190,11 @@ class ContextGraph:
             0 means using the default value (i.e. self.ac_threshold). It is
             used only when this graph is applied in a keyword-spotting system.
             The length of `ac_threshold` MUST be equal to the length of `token_ids`.
+          simple_trie:
+            True to build only the trie (i.e. no fail and output arcs); needed by
+            TCPGen biasing training.
+            False to build an Aho-Corasick automaton, for hotword / keyword
+            searching.

         Note: The phrases would have shared states, the score of the shared states is
         the MAXIMUM value among all the tokens sharing this state.
@@ -211,7 +217,6 @@ class ContextGraph:
             context_score = self.context_score if score == 0.0 else score
             threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold
             for i, token in enumerate(tokens):
-                node_next = {}
                 if token not in node.next:
                     self.num_nodes += 1
                     is_end = i == len(tokens) - 1
@@ -240,6 +245,8 @@ class ContextGraph:
                 node.next[token].phrase = phrase
                 node.next[token].ac_threshold = threshold
                 node = node.next[token]

-        self._fill_fail_output()
+        if not simple_trie:
+            self._fill_fail_output()

     def forward_one_step(
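A usage sketch (editorial, not from the commit) of the new flag: `simple_trie=True` skips the Aho-Corasick fail/output arcs, which TCPGen training does not need, while the default still builds the full automaton for hotword search:

from icefall import ContextGraph

graph = ContextGraph(0)
# two biasing phrases given as token-id lists, trie only (no fail arcs)
graph.build([[5, 7], [5, 8, 9]], simple_trie=True)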