icefall/egs/libriheavy/LM/zipformer1/chunk_decoder.py

#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random

import torch
from torch import nn, Tensor


class ChunkDecoder(nn.Module):
    """
    A 'decoder' that computes the probability of symbols in a language-modeling
    task.  Conceptually it computes the probability of `chunk_size` symbols
    (e.g. 8 symbols) from an embedding derived from all symbols preceding this
    chunk; within the chunk, each position additionally sees all previous
    symbols of the chunk (plus the last symbol of the previous chunk).
    """

    def __init__(self,
                 embed_dim: int,
                 chunk_size: int,
                 vocab_size: int,
                 hidden_size: int,
                 num_layers: int = 2):
        """
        Args:
          embed_dim: dimension of the label embeddings and of the encoder output.
          chunk_size: number of symbols per chunk, e.g. 8.
          vocab_size: number of symbols in the vocabulary.
          hidden_size: hidden dimension of the LSTM.
          num_layers: number of LSTM layers.
        """
        super().__init__()
        self.chunk_size = chunk_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size=embed_dim,
                            hidden_size=hidden_size,
                            num_layers=num_layers)
        self.label_embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim)
        # Project to the hidden state and cell state at the beginning of each
        # chunk.  (We don't run the LSTM continuously over the sequence, for
        # both training speed and stability; instead, we reconstruct the
        # hidden state for each chunk.)  The factor of 2 covers the hidden
        # state and the cell state.
        self.init_proj = nn.Linear(embed_dim,
                                   2 * hidden_size * num_layers)
        self.out_proj = nn.Linear(hidden_size,
                                  vocab_size)
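        # Shape illustration (added commentary): for one chunk embedding of
        # shape (batch_size, embed_dim), init_proj produces
        # (batch_size, 2 * hidden_size * num_layers), which forward() reshapes
        # into an initial hidden state and cell state, each of shape
        # (num_layers, batch_size, hidden_size), for the LSTM.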

    def forward(self,
                labels: Tensor,
                encoder_embed: Tensor) -> Tensor:
        """
        Compute log-probs.
        Args:
          labels: the labels, a Tensor of integer type of shape (batch_size, seq_len);
            seq_len is expected to be a multiple of chunk_size.
          encoder_embed: the embeddings from the encoder, of shape
            (seq_len//chunk_size, batch_size, embed_dim)
        Returns:
          the log-probs for each symbol, in a Tensor of shape (batch_size, seq_len).
        """
        (batch_size, seq_len) = labels.shape
        (num_chunks, _batch_size, embed_dim) = encoder_embed.shape
        chunk_size = self.chunk_size
        assert batch_size == _batch_size
        assert num_chunks * chunk_size == seq_len

        # Shift the labels right by one position along the time axis, so that
        # position t is predicted from the label at position t - 1 (position 0
        # sees a zero label).
        labels_shifted = torch.cat((torch.zeros_like(labels[:, :1]),
                                    labels[:, :-1]), dim=1)
        labels_embed = self.label_embed(labels_shifted.t())  # (seq_len, batch_size, embed_dim)
        # Reconstruct the LSTM's initial state for every chunk from the
        # encoder embedding of that chunk.
        init = self.init_proj(encoder_embed).reshape(num_chunks, batch_size,
                                                     2, self.num_layers,
                                                     self.hidden_size)
        init = init.permute(2, 3, 0, 1, 4).reshape(2, self.num_layers,
                                                   num_chunks * batch_size,
                                                   self.hidden_size).contiguous()
        hidden = init[0]  # (num_layers, num_chunks * batch_size, hidden_size)
        cell = init[1]    # (num_layers, num_chunks * batch_size, hidden_size)

        # Fold the chunk dimension into the batch dimension so that every
        # chunk is processed as an independent sequence of length chunk_size.
        labels_embed = labels_embed.reshape(num_chunks, chunk_size, batch_size,
                                            embed_dim).transpose(0, 1)
        labels_embed = labels_embed.contiguous().reshape(chunk_size,
                                                         num_chunks * batch_size,
                                                         embed_dim)
        encoder_embed = encoder_embed.reshape(1, num_chunks * batch_size, embed_dim)
        x = labels_embed + encoder_embed  # broadcasts encoder_embed over chunk_size
        (x, _hidden) = self.lstm(x, hx=(hidden, cell))
        x = self.out_proj(x)
        vocab_size = x.shape[-1]
        # x: (chunk_size, num_chunks * batch_size, vocab_size)
        x = x.reshape(chunk_size, num_chunks, batch_size, vocab_size)
        x = x.permute(2, 1, 0, 3).contiguous().reshape(batch_size,
                                                       num_chunks * chunk_size,
                                                       vocab_size)
        x = x.log_softmax(dim=-1)
        # Pick out the log-prob of the reference label at each position.
        logprobs = torch.gather(x, dim=-1,
                                index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)

        if random.random() < 0.02:
            # Occasionally print out the average logprob per position in the chunk.
            avg = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
            avg = avg.to('cpu').tolist()
            logging.info(f"Logprobs per position in chunk: {avg}")
        return logprobs
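

if __name__ == "__main__":
    # Minimal smoke test (added; not part of the original file).  All sizes
    # below are illustrative assumptions, chosen only to exercise the shapes
    # documented in forward().
    logging.basicConfig(level=logging.INFO)
    torch.manual_seed(0)
    batch_size, num_chunks, chunk_size = 3, 4, 8
    embed_dim, vocab_size, hidden_size = 16, 100, 32
    decoder = ChunkDecoder(embed_dim=embed_dim,
                           chunk_size=chunk_size,
                           vocab_size=vocab_size,
                           hidden_size=hidden_size)
    labels = torch.randint(vocab_size, (batch_size, num_chunks * chunk_size))
    encoder_embed = torch.randn(num_chunks, batch_size, embed_dim)
    logprobs = decoder(labels, encoder_embed)
    assert logprobs.shape == (batch_size, num_chunks * chunk_size)
    assert (logprobs <= 0).all()  # log-probs are never positive
    print("ChunkDecoder smoke test passed:", tuple(logprobs.shape))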