# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
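"""TCPGen-style contextual biasing for the pruned transducer recipes in icefall.

This module builds biasing masks from a ContextGraph of context phrases, computes a
pointer distribution over the vocabulary plus an out-of-biasing-list (OOL) token, and
produces a generation probability (see ``TCPGen.generator_prob``) that downstream code
can use to weight the biased distribution.
"""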
import logging
import math
import random
from typing import Tuple

import torch
import torch.nn as nn

from icefall import ContextGraph


class TCPGen(nn.Module):
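    """Tree-constrained pointer generator (TCPGen) biasing module.

    The module attends from the pruned (B, T, s_range) lattice positions of a pruned
    RNN-T model to the symbol embeddings, masked so that only valid continuations of
    the biasing phrases (stored in a ContextGraph) receive probability mass. It
    returns a pointer distribution over the vocabulary plus an out-of-biasing-list
    (OOL) token, and ``generator_prob`` turns the joiner activations and the pointer
    embedding into a generation probability that downstream code can use when
    combining the pointer distribution with the model's own distribution.
    """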
    def __init__(
        self,
        encoder_dim: int,
        embed_dim: int,
        joiner_dim: int,
        decoder: nn.Module,
        attn_dim: int = 512,
        tcpgen_dropout: float = 0.1,
    ):
        super().__init__()
        # embedding for the OOL (out-of-biasing-list) token
        self.oolemb = torch.nn.Embedding(1, embed_dim)
        # project symbol embeddings (queries from the decoder side)
        self.q_proj_sym = torch.nn.Linear(embed_dim, attn_dim)
        # project encoder embeddings (queries from the acoustic side)
        self.q_proj_acoustic = torch.nn.Linear(encoder_dim, attn_dim)
        # project the key/value symbol embeddings (vocabulary + OOL)
        self.k_proj = torch.nn.Linear(embed_dim, attn_dim)
        # generate the TCPGen generation probability
        self.tcp_gate = torch.nn.Linear(attn_dim + joiner_dim, 1)
        self.dropout_tcpgen = torch.nn.Dropout(tcpgen_dropout)
        self.decoder = decoder
        self.vocab_size = decoder.vocab_size
    def get_tcpgen_masks(
        self,
        targets: torch.Tensor,
        context_graph: ContextGraph,
        vocab_size: int,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
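        """Generate the TCPGen masks by walking the context graph along the targets.

        The biasing phrases are stored in ``context_graph`` as a prefix tree. For
        every position in ``targets`` we advance one step in the tree. If the symbol
        continues a biasing phrase (or is a blank), the position stays unmasked in
        ``gen_masks`` (value 0) and the tokens that would continue a phrase from the
        new tree node stay unmasked in ``dist_masks``; completing a phrase or seeing
        a blank resets the walk to the root. If the symbol falls off the tree, the
        position is masked in ``gen_masks`` (value 1), the whole ``dist_masks`` row
        stays masked, and the walk restarts from the root.

        Illustrative example with hypothetical token ids: for a single biasing phrase
        ``[5, 7]`` and a target sequence ``[5, 7, 3]``, the resulting ``gen_mask`` is
        ``[0, 0, 1]``; ``dist_masks`` keeps only token 7 unmasked at position 0, only
        the phrase-initial token 5 at position 1 (the walk has just reset), and
        nothing at position 2.

        Args:
          targets:
            The training targets in token ids (padded with blanks), with shape (B, S).
          context_graph:
            The context graph containing the biasing phrases.
          vocab_size:
            The vocabulary size (excluding the extra OOL token).
          blank_id:
            The id of the blank token.

        Returns:
          Return a tuple of boolean masks (dist_masks, gen_masks) with shapes
          (B, S, V + 1) and (B, S); True means the entry is masked out.
        """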
        batch_size, seq_len = targets.shape
        dist_masks = torch.ones((batch_size, seq_len, vocab_size + 1))
        gen_masks = []
        yseqs = targets.tolist()
        for i, yseq in enumerate(yseqs):
            node = context_graph.root
            gen_mask = []
            for j, y in enumerate(yseq):
                not_matched = False
                if y == blank_id:
                    node = context_graph.root
                    gen_mask.append(0)
                elif y in node.next:
                    gen_mask.append(0)
                    node = node.next[y]
                    if node.is_end:
                        node = context_graph.root
                else:
                    gen_mask.append(1)
                    node = context_graph.root
                    not_matched = True
                # Unmask the valid continuations of the current (possibly reset)
                # tree node; a non-matching symbol leaves the whole row masked.
                # NOTE: an earlier variant also unmasked the OOL column
                # (index vocab_size) here.
                if not not_matched:
                    dist_masks[i, j, list(node.next.keys())] = 0

            gen_masks.append(gen_mask + [1] * (seq_len - len(gen_mask)))
        gen_masks = torch.Tensor(gen_masks).to(targets.device).bool()
        dist_masks = dist_masks.to(targets.device).bool()
        if random.random() >= 0.95:
            logging.info(
                f"gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=2)}"
            )
        return dist_masks, gen_masks
    def get_tcpgen_distribution(
        self, query: torch.Tensor, dist_masks: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the TCPGen pointer distribution and its weighted embedding.

        Args:
          query:
            The pruned query (symbol + acoustic), with shape (B, T, s_range, attn_dim).
          dist_masks:
            The pruned distribution masks, with shape (B, T, s_range, V + 1).

        Returns:
          Return a tuple of (hptr, distribution); hptr has shape
          (B, T, s_range, attn_dim) and distribution has shape (B, T, s_range, V + 1).
        """
        # Following the original paper, keys and values share the same embeddings.
        # (V + 1, embed_dim)
        kv = torch.cat([self.decoder.embedding.weight.data, self.oolemb.weight], dim=0)
        # (V + 1, attn_dim)
        kv = self.dropout_tcpgen(self.k_proj(kv))
        # (B, T, s_range, attn_dim) x (attn_dim, V + 1) -> (B, T, s_range, V + 1)
        distribution = torch.matmul(query, kv.permute(1, 0)) / math.sqrt(query.size(-1))
        distribution = distribution.masked_fill(
            dist_masks, torch.finfo(distribution.dtype).min
        )
        distribution = distribution.softmax(dim=-1)
        # Exclude the OOL column when computing the weighted sum of embeddings:
        # (B, T, s_range, V) x (V, attn_dim) -> (B, T, s_range, attn_dim)
        hptr = torch.matmul(distribution[:, :, :, :-1], kv[:-1, :])
        hptr = self.dropout_tcpgen(hptr)
        if random.random() > 0.95:
            logging.info(
                f"distribution mean : {torch.mean(distribution, dim=3)}, std: {torch.std(distribution, dim=3)}"
            )
            logging.info(
                f"distribution min : {torch.min(distribution, dim=3)}, max: {torch.max(distribution, dim=3)}"
            )
        return hptr, distribution
    def prune_query_and_mask(
        self,
        query_sym: torch.Tensor,
        query_acoustic: torch.Tensor,
        dist_masks: torch.Tensor,
        gen_masks: torch.Tensor,
        ranges: torch.Tensor,
    ):
        """Prune the symbol and acoustic queries (and the corresponding masks) with
        the ranges generated by `get_rnnt_prune_ranges` in the pruned rnnt loss.

        Args:
          query_sym:
            The symbol query, with shape (B, S, attn_dim).
          query_acoustic:
            The acoustic query, with shape (B, T, attn_dim).
          dist_masks:
            The TCPGen distribution masks, with shape (B, S, V + 1).
          gen_masks:
            The TCPGen probability masks, with shape (B, S).
          ranges:
            A tensor containing the symbol indexes for each frame that we want to
            keep. Its shape is (B, T, s_range); see the docs of
            `get_rnnt_prune_ranges` in rnnt_loss.py for more details of this tensor.

        Returns:
          Return a tuple of (query, dist_masks, gen_masks): the pruned query
          (symbol plus acoustic) with shape (B, T, s_range, attn_dim), the pruned
          distribution masks with shape (B, T, s_range, V + 1) and the pruned
          probability masks with shape (B, T, s_range).
        """
        assert ranges.shape[0] == query_sym.shape[0], (ranges.shape, query_sym.shape)
        assert ranges.shape[0] == query_acoustic.shape[0], (
            ranges.shape,
            query_acoustic.shape,
        )
        assert query_acoustic.shape[1] == ranges.shape[1], (
            query_acoustic.shape,
            ranges.shape,
        )
        (B, T, s_range) = ranges.shape
        (B, S, attn_dim) = query_sym.shape
        assert query_acoustic.shape == (B, T, attn_dim), (
            query_acoustic.shape,
            (B, T, attn_dim),
        )
        assert dist_masks.shape == (B, S, self.vocab_size + 1), (
            dist_masks.shape,
            (B, S, self.vocab_size + 1),
        )
        assert gen_masks.shape == (B, S), (gen_masks.shape, (B, S))
        # (B, T, s_range, attn_dim)
        query_acoustic_pruned = query_acoustic.unsqueeze(2).expand(
            (B, T, s_range, attn_dim)
        )
        # (B, T, s_range, attn_dim)
        query_sym_pruned = torch.gather(
            query_sym,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, attn_dim)),
        ).reshape(B, T, s_range, attn_dim)
        # (B, T, s_range, V + 1)
        dist_masks_pruned = torch.gather(
            dist_masks,
            dim=1,
            index=ranges.reshape(B, T * s_range, 1).expand(
                (B, T * s_range, self.vocab_size + 1)
            ),
        ).reshape(B, T, s_range, self.vocab_size + 1)
        # (B, T, s_range)
        gen_masks_pruned = torch.gather(
            gen_masks, dim=1, index=ranges.reshape(B, T * s_range)
        ).reshape(B, T, s_range)
        return (
            query_sym_pruned + query_acoustic_pruned,
            dist_masks_pruned,
            gen_masks_pruned,
        )
    def forward(
        self,
        targets: torch.Tensor,
        encoder_embeddings: torch.Tensor,
        ranges: torch.Tensor,
        context_graph: ContextGraph,
    ):
        """Compute the TCPGen embedding and distribution for the pruned lattice positions.

        Args:
          targets:
            The training targets in token ids (padded with blanks), with shape (B, S).
          encoder_embeddings:
            The encoder outputs, with shape (B, T, encoder_dim).
          ranges:
            The prune ranges from the pruned rnnt loss, with shape (B, T, s_range).
          context_graph:
            The context graph containing the biasing phrases; it is shared by all
            utterances in the batch.

        Returns:
          Return the TCPGen embedding (hptr) with shape (B, T, s_range, attn_dim),
          the TCPGen distribution with shape (B, T, s_range, V + 1) and the pruned
          generation-probability masks with shape (B, T, s_range).
        """
        query_sym = self.decoder.embedding(targets)  # (B, S, embed_dim)

        query_sym = self.q_proj_sym(query_sym)  # (B, S, attn_dim)
        query_acoustic = self.q_proj_acoustic(encoder_embeddings)  # (B, T, attn_dim)

        # dist_masks : (B, S, V + 1)
        # gen_masks : (B, S)
        dist_masks, gen_masks = self.get_tcpgen_masks(
            targets=targets, context_graph=context_graph, vocab_size=self.vocab_size
        )
        # query : (B, T, s_range, attn_dim)
        # dist_masks : (B, T, s_range, V + 1)
        # gen_masks : (B, T, s_range)
        query, dist_masks, gen_masks = self.prune_query_and_mask(
            query_sym=query_sym,
            query_acoustic=query_acoustic,
            dist_masks=dist_masks,
            gen_masks=gen_masks,
            ranges=ranges,
        )

        if random.random() >= 0.95:
            logging.info(
                f"pruned gen_mask nonzero {gen_masks.shape} : {torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"pruned dist_masks nonzero {dist_masks.shape} : {torch.count_nonzero(torch.logical_not(dist_masks), dim=3)}"
            )

        # hptr : (B, T, s_range, attn_dim)
        # tcpgen_dist : (B, T, s_range, V + 1)
        hptr, tcpgen_dist = self.get_tcpgen_distribution(query, dist_masks)
        return hptr, tcpgen_dist, gen_masks
    def generator_prob(
        self, hptr: torch.Tensor, h_joiner: torch.Tensor, gen_masks: torch.Tensor
    ) -> torch.Tensor:
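        """Compute the TCPGen generation probability from the joiner activations and
        the pointer embedding.

        Args:
          hptr:
            The TCPGen embedding returned by forward(), with shape
            (B, T, s_range, attn_dim).
          h_joiner:
            The joiner activations at the same pruned positions, with shape
            (B, T, s_range, joiner_dim).
          gen_masks:
            The pruned generation-probability masks returned by forward(), with
            shape (B, T, s_range); masked positions get probability 0.

        Returns:
          Return the generation probability with shape (B, T, s_range, 1).
        """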
        # tcpgen_prob : (B, T, s_range, 1)
        tcpgen_prob = self.tcp_gate(torch.cat((h_joiner, hptr), dim=-1))
        tcpgen_prob = torch.sigmoid(tcpgen_prob)
        tcpgen_prob = tcpgen_prob.masked_fill(gen_masks.unsqueeze(-1), 0)
        if random.random() >= 0.95:
            logging.info(
                f"tcpgen_prob mean : {torch.mean(tcpgen_prob.squeeze(-1), dim=(1, 2))}, std : {torch.std(tcpgen_prob.squeeze(-1), dim=(1, 2))}"
            )
            logging.info(
                f"tcpgen_prob min : {torch.min(tcpgen_prob.squeeze(-1), dim=1)}, max : {torch.max(tcpgen_prob.squeeze(-1), dim=1)}"
            )
        return tcpgen_prob
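if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the original recipe. The dummy decoder
    # below is a stand-in that exposes only the attributes TCPGen reads (`embedding`
    # and `vocab_size`), and the ranges are built by hand to mimic the monotone
    # output of get_rnnt_prune_ranges. The ContextGraph calls assume icefall's usual
    # API: ContextGraph(context_score=...) and build() taking a list of token-id lists.
    class _DummyDecoder(nn.Module):
        def __init__(self, vocab_size: int, embed_dim: int):
            super().__init__()
            self.vocab_size = vocab_size
            self.embedding = nn.Embedding(vocab_size, embed_dim)

    vocab_size, embed_dim, encoder_dim, joiner_dim, attn_dim = 10, 8, 16, 6, 12
    B, T, S, s_range = 2, 5, 4, 2

    tcpgen = TCPGen(
        encoder_dim=encoder_dim,
        embed_dim=embed_dim,
        joiner_dim=joiner_dim,
        decoder=_DummyDecoder(vocab_size, embed_dim),
        attn_dim=attn_dim,
    )

    context_graph = ContextGraph(context_score=1.0)
    context_graph.build([[3, 4], [5]])  # two biasing "phrases" of token ids

    targets = torch.randint(0, vocab_size, (B, S))
    encoder_out = torch.randn(B, T, encoder_dim)

    # Each frame keeps s_range consecutive symbol positions, non-decreasing in t.
    starts = torch.linspace(0, S - s_range, steps=T).long()  # (T,)
    ranges = (starts.unsqueeze(1) + torch.arange(s_range)).expand(B, T, s_range)

    hptr, tcpgen_dist, gen_masks = tcpgen(targets, encoder_out, ranges, context_graph)
    h_joiner = torch.randn(B, T, s_range, joiner_dim)
    tcpgen_prob = tcpgen.generator_prob(hptr, h_joiner, gen_masks)
    print(hptr.shape, tcpgen_dist.shape, gen_masks.shape, tcpgen_prob.shape)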