# Copyright 2024 Xiaomi Corp. (authors: Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
import random
from typing import Tuple

import torch
import torch.nn as nn

from icefall import ContextGraph


class TCPGen(nn.Module):
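    """Tree-Constrained Pointer Generator (TCPGen) biasing module.

    TCPGen walks a prefix tree (a ``ContextGraph``) built from a list of
    biasing phrases and, for each (frame, symbol) position kept by the pruned
    RNN-T ranges, attends over the tokens that extend the current biasing
    prefix plus an OOL (out-of-biasing-list) token. See Sun et al.,
    "Tree-Constrained Pointer Generator for End-to-End Contextual Speech
    Recognition", for the original formulation.
    """
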
def __init__(
self,
encoder_dim: int,
embed_dim: int,
joiner_dim: int,
decoder: nn.Module,
attn_dim: int = 512,
tcpgen_dropout: float = 0.1,
):
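        """
        Args:
          encoder_dim:
            The dimension of the encoder outputs.
          embed_dim:
            The dimension of the decoder symbol embeddings.
          joiner_dim:
            The dimension of the joiner activations fed to the TCPGen gate.
          decoder:
            The decoder (predictor); its embedding table is reused as the
            TCPGen attention keys and values.
          attn_dim:
            The dimension of the TCPGen attention space.
          tcpgen_dropout:
            The dropout rate applied to the keys and to the output embedding.
        """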
super().__init__()
        # embedding for the OOL (out-of-biasing-list) token
        self.oolemb = torch.nn.Embedding(1, embed_dim)
        # project symbol embeddings to the attention space
        self.q_proj_sym = torch.nn.Linear(embed_dim, attn_dim)
        # project encoder embeddings to the attention space
        self.q_proj_acoustic = torch.nn.Linear(encoder_dim, attn_dim)
        # project the key/value embeddings (vocabulary + OOL)
        self.k_proj = torch.nn.Linear(embed_dim, attn_dim)
        # gate generating the TCPGen probability
        self.tcp_gate = torch.nn.Linear(attn_dim + joiner_dim, 1)
        self.dropout_tcpgen = torch.nn.Dropout(tcpgen_dropout)
self.decoder = decoder
self.vocab_size = decoder.vocab_size
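
    # A minimal construction sketch; the decoder class and all dimensions
    # below are hypothetical, only the keyword names match this module:
    #
    #   decoder = Decoder(vocab_size=500, decoder_dim=512, ...)
    #   tcpgen = TCPGen(
    #       encoder_dim=384,
    #       embed_dim=512,
    #       joiner_dim=512,
    #       decoder=decoder,
    #       attn_dim=512,
    #   )
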
def get_tcpgen_masks(
self,
targets: torch.Tensor,
context_graph: ContextGraph,
vocab_size: int,
blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
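        """Build the TCPGen masks by walking ``context_graph`` along each
        target sequence.

        Args:
          targets:
            The padded target token ids, with shape (B, S).
          context_graph:
            The prefix tree of biasing phrases, shared by the whole batch.
          vocab_size:
            The vocabulary size (excluding the extra OOL token).
          blank_id:
            The id of the blank; it resets the walk to the root.
        Returns:
          Return two boolean masks: dist_masks of shape (B, S, V + 1), False
          at the tokens that extend the current biasing prefix, and gen_masks
          of shape (B, S), True at positions where the TCPGen probability will
          be forced to zero.
        """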
        batch_size, seq_len = targets.shape
        dist_masks = torch.ones((batch_size, seq_len, vocab_size + 1))
        gen_masks = []
        yseqs = targets.tolist()
        for i, yseq in enumerate(yseqs):
            node = context_graph.root
            gen_mask = []
            for j, y in enumerate(yseq):
                matched = True
                if y == blank_id:
                    # blank resets the walk to the root
                    node = context_graph.root
                    gen_mask.append(0)
                elif y in node.next:
                    # y extends the current biasing prefix
                    gen_mask.append(0)
                    node = node.next[y]
                    if node.is_end:
                        node = context_graph.root
                else:
                    # y falls off the tree; restart from the root
                    gen_mask.append(1)
                    node = context_graph.root
                    matched = False
                if matched:
                    # unmask the tokens that extend the current prefix
                    dist_masks[i, j, list(node.next.keys())] = 0
            gen_masks.append(gen_mask + [1] * (seq_len - len(gen_mask)))
        gen_masks = torch.tensor(gen_masks, device=targets.device, dtype=torch.bool)
        dist_masks = dist_masks.to(targets.device).bool()
        if random.random() >= 0.95:
            logging.info(
                f"gen_masks unmasked count {gen_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"dist_masks unmasked count {dist_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(dist_masks), dim=2)}"
            )
return dist_masks, gen_masks
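
    # A worked sketch of the walk above, assuming a context graph built from
    # the single (hypothetical) biasing phrase [5, 6], blank_id = 0 and
    # vocab_size = 10:
    #
    #   targets = [[0, 5, 6, 3]]
    #   j=0, y=0 (blank): reset to root; unmask {5};       gen_mask 0
    #   j=1, y=5 (match): advance;       unmask {6};       gen_mask 0
    #   j=2, y=6 (end):   back to root;  unmask {5};       gen_mask 0
    #   j=3, y=3 (miss):  back to root;  nothing unmasked; gen_mask 1
    #
    # so gen_masks = [[0, 0, 0, 1]] and dist_masks is True everywhere except
    # the unmasked entries listed above.
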
def get_tcpgen_distribution(
self, query: torch.Tensor, dist_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
query:
shape : (B, T, s_range, attn_dim)
dist_masks:
shape : (B, T, s_range, V + 1)
"""
        # Following the original paper, the keys and values share the decoder
        # embeddings; `.data` keeps this path out of the embedding's autograd.
        # (V + 1, embed_dim)
        kv = torch.cat([self.decoder.embedding.weight.data, self.oolemb.weight], dim=0)
        # (V + 1, attn_dim)
        kv = self.dropout_tcpgen(self.k_proj(kv))
# (B, T, s_range, attn_dim) * (attn_dim, V + 1) -> (B, T, s_range, V + 1)
distribution = torch.matmul(query, kv.permute(1, 0)) / math.sqrt(query.size(-1))
distribution = distribution.masked_fill(
dist_masks, torch.finfo(distribution.dtype).min
)
distribution = distribution.softmax(dim=-1)
        # (B, T, s_range, V) * (V, attn_dim) -> (B, T, s_range, attn_dim)
        # the OOL column is dropped, so only real tokens contribute to hptr
        hptr = torch.matmul(distribution[:, :, :, :-1], kv[:-1, :])
        hptr = self.dropout_tcpgen(hptr)
        if random.random() > 0.95:
            logging.info(
                f"distribution mean : {torch.mean(distribution, dim=3)}, "
                f"std: {torch.std(distribution, dim=3)}"
            )
            logging.info(
                f"distribution min : {torch.min(distribution, dim=3).values}, "
                f"max: {torch.max(distribution, dim=3).values}"
            )
return hptr, distribution
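
    # Shape sketch, assuming a hypothetical vocab_size V = 10 and
    # attn_dim = 512:
    #
    #   kv           : (11, 512)  # decoder embeddings (10 rows) + OOL (1 row)
    #   query @ kv.T : (B, T, s_range, 11), scaled by 1 / sqrt(512)
    #
    # The softmax mass lands only on unmasked entries (masked logits are set
    # to the dtype minimum), and hptr attends over the 10 real tokens only.
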
def prune_query_and_mask(
self,
query_sym: torch.Tensor,
query_acoustic: torch.Tensor,
dist_masks: torch.Tensor,
gen_masks: torch.Tensor,
ranges: torch.Tensor,
):
"""Prune the queries from symbols and acoustics with ranges
generated by `get_rnnt_prune_ranges` in pruned rnnt loss.
Args:
query_sym:
The symbol query, with shape (B, S, attn_dim).
query_acoustic:
The acoustic query, with shape (B, T, attn_dim).
dist_masks:
The TCPGen distribution masks, with shape (B, S, V + 1).
gen_masks:
The TCPGen probability masks, with shape (B, S).
ranges:
A tensor containing the symbol indexes for each frame that we want to
keep. Its shape is (B, T, s_range), see the docs in
`get_rnnt_prune_ranges` in rnnt_loss.py for more details of this tensor.
        Returns:
          Return the pruned query with shape (B, T, s_range, attn_dim), the
          pruned distribution masks with shape (B, T, s_range, V + 1) and the
          pruned generation masks with shape (B, T, s_range).
"""
assert ranges.shape[0] == query_sym.shape[0], (ranges.shape, query_sym.shape)
assert ranges.shape[0] == query_acoustic.shape[0], (
ranges.shape,
query_acoustic.shape,
)
assert query_acoustic.shape[1] == ranges.shape[1], (
query_acoustic.shape,
ranges.shape,
)
(B, T, s_range) = ranges.shape
(B, S, attn_dim) = query_sym.shape
assert query_acoustic.shape == (B, T, attn_dim), (
query_acoustic.shape,
(B, T, attn_dim),
)
assert dist_masks.shape == (B, S, self.vocab_size + 1), (
dist_masks.shape,
(B, S, self.vocab_size + 1),
)
assert gen_masks.shape == (B, S), (gen_masks.shape, (B, S))
# (B, T, s_range, attn_dim)
query_acoustic_pruned = query_acoustic.unsqueeze(2).expand(
(B, T, s_range, attn_dim)
)
# (B, T, s_range, attn_dim)
query_sym_pruned = torch.gather(
query_sym,
dim=1,
index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, attn_dim)),
).reshape(B, T, s_range, attn_dim)
# (B, T, s_range, V + 1)
dist_masks_pruned = torch.gather(
dist_masks,
dim=1,
index=ranges.reshape(B, T * s_range, 1).expand(
(B, T * s_range, self.vocab_size + 1)
),
).reshape(B, T, s_range, self.vocab_size + 1)
# (B, T, s_range)
gen_masks_pruned = torch.gather(
gen_masks, dim=1, index=ranges.reshape(B, T * s_range)
).reshape(B, T, s_range)
return (
query_sym_pruned + query_acoustic_pruned,
dist_masks_pruned,
gen_masks_pruned,
)
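
    # A small sketch of the gather-based pruning above, with hypothetical
    # shapes B = 1, T = 2, S = 4, s_range = 2:
    #
    #   ranges = [[[0, 1], [1, 2]]]  # frame 0 keeps symbols 0-1, frame 1: 1-2
    #   query_sym_pruned[0, t, k] == query_sym[0, ranges[0, t, k]]
    #
    # The same index tensor, expanded over the trailing dimension, gathers
    # dist_masks and gen_masks so that all three outputs stay aligned per
    # (frame, symbol) pair.
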
def forward(
self,
targets: torch.Tensor,
encoder_embeddings: torch.Tensor,
ranges: torch.Tensor,
context_graph: ContextGraph,
):
"""
Args:
target:
The training targets in token ids (padded with blanks). shape : (B, S)
encoder_embeddings:
The encoder outputs. shape: (B, T, attn_dim)
ranges:
The prune ranges from pruned rnnt. shape: (B, T, s_range)
context_graphs:
The context_graphs for each utterance. B == len(context_graphs).
Return:
returns tcpgen embedding with shape (B, T, s_range, attn_dim) and
tcpgen distribution with shape (B, T, s_range, V + 1).
"""
query_sym = self.decoder.embedding(targets)
query_sym = self.q_proj_sym(query_sym) # (B, S, attn_dim)
query_acoustic = self.q_proj_acoustic(encoder_embeddings) # (B , T, attn_dim)
# dist_masks : (B, S, V + 1)
# gen_masks : (B, S)
dist_masks, gen_masks = self.get_tcpgen_masks(
targets=targets, context_graph=context_graph, vocab_size=self.vocab_size
)
# query : (B, T, s_range, attn_dim)
# dist_masks : (B, T, s_range, V + 1)
query, dist_masks, gen_masks = self.prune_query_and_mask(
query_sym=query_sym,
query_acoustic=query_acoustic,
dist_masks=dist_masks,
gen_masks=gen_masks,
ranges=ranges,
)
        if random.random() >= 0.95:
            logging.info(
                f"pruned gen_masks unmasked count {gen_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(gen_masks), dim=1)}"
            )
            logging.info(
                f"pruned dist_masks unmasked count {dist_masks.shape} : "
                f"{torch.count_nonzero(torch.logical_not(dist_masks), dim=3)}"
            )
# hptr : (B, T, s_range, attn_dim)
# tcpgen_dist : (B, T, s_range, V + 1)
hptr, tcpgen_dist = self.get_tcpgen_distribution(query, dist_masks)
return hptr, tcpgen_dist, gen_masks
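
    # An end-to-end sketch; the tensor shapes are hypothetical and `ranges`
    # is assumed to come from `get_rnnt_prune_ranges` (see the docstring of
    # `prune_query_and_mask`):
    #
    #   targets = torch.zeros(4, 20, dtype=torch.int64)   # (B, S)
    #   encoder_out = torch.randn(4, 100, 384)            # (B, T, encoder_dim)
    #   ranges = ...                                      # (B, T, s_range)
    #   hptr, tcpgen_dist, gen_masks = tcpgen(
    #       targets, encoder_out, ranges, context_graph
    #   )
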
def generator_prob(
self, hptr: torch.Tensor, h_joiner: torch.Tensor, gen_masks: torch.Tensor
) -> torch.Tensor:
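        """Compute the TCPGen generation probability.

        Args:
          hptr:
            The TCPGen output embedding, with shape (B, T, s_range, attn_dim).
          h_joiner:
            The joiner activations, with shape (B, T, s_range, joiner_dim).
          gen_masks:
            The generation masks, with shape (B, T, s_range); True positions
            get zero probability.
        Returns:
          Return the generation probability with shape (B, T, s_range, 1).
        """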
# tcpgen_prob : (B, T, s_range, 1)
tcpgen_prob = self.tcp_gate(torch.cat((h_joiner, hptr), dim=-1))
tcpgen_prob = torch.sigmoid(tcpgen_prob)
tcpgen_prob = tcpgen_prob.masked_fill(gen_masks.unsqueeze(-1), 0)
        if random.random() >= 0.95:
            logging.info(
                f"tcpgen_prob mean : {torch.mean(tcpgen_prob.squeeze(-1), dim=(1, 2))}, "
                f"std : {torch.std(tcpgen_prob.squeeze(-1), dim=(1, 2))}"
            )
            logging.info(
                f"tcpgen_prob min : {torch.min(tcpgen_prob.squeeze(-1), dim=1).values}, "
                f"max : {torch.max(tcpgen_prob.squeeze(-1), dim=1).values}"
            )
return tcpgen_prob
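
# How generator_prob is typically consumed (a sketch following the TCPGen
# paper, not code from this file): the final output distribution interpolates
# the model distribution with the TCPGen distribution,
#
#   p_final = (1 - tcpgen_prob) * p_model + tcpgen_prob * tcpgen_dist
#
# where tcpgen_prob has already been zeroed (via gen_masks) at positions that
# are not inside any biasing prefix.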