#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random

import torch
from torch import nn, Tensor


class ChunkDecoder(nn.Module):
    """
    A 'decoder' that computes the probability of symbols in a language modeling task.

    Conceptually it computes the probability of `chunk_size` symbols (e.g. 8 symbols)
    based on an embedding derived from all symbols preceding this chunk of 8 symbols.
    Also, within the chunk, we always see all previous symbols (plus the last symbol
    of the previous chunk).
    """

    def __init__(
        self,
        embed_dim: int,
        chunk_size: int,
        vocab_size: int,
        hidden_size: int,
        num_layers: int = 2,
    ):
        """
        Args:
          embed_dim: dimension of the per-chunk encoder embeddings and of the
            label embeddings.
          chunk_size: number of symbols per chunk (e.g. 8).
          vocab_size: number of symbols in the vocabulary.
          hidden_size: hidden dimension of the LSTM.
          num_layers: number of LSTM layers.
        """
        super().__init__()
        self.chunk_size = chunk_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
        )

        self.label_embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
        )

        # Project to the hidden state and cell state at the beginning of each chunk.
        # (We don't run the LSTM continuously over the sequence, for both
        # training speed and stability; instead, we reconstruct the hidden
        # state for each chunk.)
        # The factor of 2 is to cover the hidden state and the cell state.
        self.init_proj = nn.Linear(embed_dim, 2 * hidden_size * num_layers)

        self.out_proj = nn.Linear(hidden_size, vocab_size)

    def forward(self, labels: Tensor, encoder_embed: Tensor) -> Tensor:
        """
        Compute log-probs.

        Args:
          labels: the labels, a Tensor of integer type of shape (batch_size, seq_len);
            seq_len is expected to be a multiple of chunk_size.
          encoder_embed: the embeddings from the encoder, of shape
            (seq_len//chunk_size, batch_size, embed_dim)

        Returns:
          returns the log-probs for each symbol, in a Tensor of shape (batch_size, seq_len).
        """
        (batch_size, seq_len) = labels.shape
        (num_chunks, _batch_size, embed_dim) = encoder_embed.shape
        chunk_size = self.chunk_size
        assert batch_size == _batch_size
        assert num_chunks * chunk_size == seq_len

        # Shift the labels right by one position along the time axis, so that the
        # input at position t is the label at position t - 1 (position 0 gets a
        # zero, i.e. an empty history).
        labels_shifted = torch.cat(
            (torch.zeros_like(labels[:, :1]), labels[:, :-1]), dim=1
        )

        labels_embed = self.label_embed(labels_shifted.t())  # (seq_len, batch_size, embed_dim)

        # Project each chunk's encoder embedding to the initial LSTM states.
        init = self.init_proj(encoder_embed).reshape(
            num_chunks, batch_size, 2, self.num_layers, self.hidden_size
        )
        # (2, num_layers, num_chunks * batch_size, hidden_size): each chunk is
        # treated as an independent sequence in the LSTM's batch dimension.
        init = (
            init.permute(2, 3, 0, 1, 4)
            .reshape(2, self.num_layers, num_chunks * batch_size, self.hidden_size)
            .contiguous()
        )
        hidden = init[0]
        cell = init[1]

        # (seq_len, batch_size, embed_dim) -> (chunk_size, num_chunks * batch_size, embed_dim)
        labels_embed = labels_embed.reshape(
            num_chunks, chunk_size, batch_size, embed_dim
        ).transpose(0, 1)
        labels_embed = labels_embed.contiguous().reshape(
            chunk_size, num_chunks * batch_size, embed_dim
        )
        encoder_embed = encoder_embed.reshape(1, num_chunks * batch_size, embed_dim)

        x = labels_embed + encoder_embed  # broadcasts encoder_embed over the chunk_size

        (x, _hidden) = self.lstm(x, hx=(hidden, cell))

        x = self.out_proj(x)

        vocab_size = x.shape[-1]
        # x: (chunk_size, num_chunks * batch_size, vocab_size)
        x = x.reshape(chunk_size, num_chunks, batch_size, vocab_size)
        x = (
            x.permute(2, 1, 0, 3)
            .contiguous()
            .reshape(batch_size, num_chunks * chunk_size, vocab_size)
        )

        x = x.log_softmax(dim=-1)

        logprobs = torch.gather(
            x, dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)  # (batch_size, seq_len)

        if random.random() < 0.02:
            # Occasionally print out the average logprob per position in the chunk.
            l = logprobs.reshape(batch_size, num_chunks, chunk_size).mean(dim=(0, 1))
            l = l.to("cpu").tolist()
            logging.info(f"Logprobs per position in chunk: {l}")

        return logprobs
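

if __name__ == "__main__":
    # A minimal smoke test, not part of the original recipe: it only illustrates
    # the expected tensor shapes described in the docstrings above. All sizes
    # below (vocab, chunk count, dimensions) are illustrative assumptions.
    torch.manual_seed(0)
    batch_size, chunk_size, num_chunks = 3, 8, 5
    seq_len = num_chunks * chunk_size
    vocab_size, embed_dim, hidden_size = 500, 256, 384

    decoder = ChunkDecoder(
        embed_dim=embed_dim,
        chunk_size=chunk_size,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
    )
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    encoder_embed = torch.randn(num_chunks, batch_size, embed_dim)

    logprobs = decoder(labels, encoder_embed)
    assert logprobs.shape == (batch_size, seq_len)
    print("logprobs shape:", logprobs.shape)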