2022-03-06 23:14:03 +08:00

188 lines
5.4 KiB
Python

# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional
import torch
from model import Transducer
# TODO(fangjun): Add more documentation
# The force alignment problem can be formulated as finding
# a path in a rectangular lattice, where the path starts
# from the lower left corner and ends at the upper right
# corner. The horizontal axis of the lattice is `t`
# and the vertical axis is `u`. At every frame the path
# either moves horizontally (emitting a blank) or
# diagonally (emitting the next token of the transcript).
#
# AlignItem is a node in the lattice, where len(ys)
# equals `t` and pos_u is the `u` coordinate in the
# lattice.
@dataclass
class AlignItem:
    """A single hypothesis, i.e. one node of the alignment lattice.

    Field order matters: the generated ``__init__`` also accepts the
    fields positionally as (log_prob, ys, pos_u).
    """

    # Accumulated log-probability of the path that reaches this node.
    log_prob: float
    # Symbols emitted so far, one per frame (blanks included); its length
    # is the `t` coordinate of the node.
    ys: List[int]
    # How many tokens of the reference sequence have been emitted so far,
    # i.e. the `u` coordinate of the node.
    pos_u: int
class AlignItemList:
    """A list-like collection of ``AlignItem`` hypotheses.

    Supports ``iter()``, ``len()``, integer indexing and ``append``, plus a
    few helpers used by the alignment beam search.
    """

    def __init__(self, items: Optional[List["AlignItem"]] = None):
        # The given list is stored as-is (it is not copied).
        self.data = items if items is not None else []

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i: int) -> "AlignItem":
        return self.data[i]

    def append(self, item: "AlignItem") -> None:
        self.data.append(item)

    def get_active_items(self, T: int, U: int) -> "AlignItemList":
        """Keep only the items from which the lattice corner (T, U) is reachable.

        An item sitting at (t, u) = (len(item.ys), item.pos_u) survives iff
        ``U - u <= T - t``: the vertical distance still to cover must not
        exceed the horizontal distance that is left.
        """
        reachable = [
            item for item in self.data if U - item.pos_u <= T - len(item.ys)
        ]
        return AlignItemList(reachable)

    def get_decoder_input(
        self,
        ys: List[int],
        context_size: int,
        blank_id: int,
    ) -> List[List[int]]:
        """Build one decoder context per item.

        The context of an item is the ``context_size`` entries of ``ys``
        immediately before index ``item.pos_u``. It is left-padded with
        ``blank_id`` when fewer than ``context_size`` entries precede it.
        """
        padded = [blank_id] * context_size + ys
        return [
            padded[item.pos_u : item.pos_u + context_size]
            for item in self.data
        ]

    def topk(self, k: int) -> "AlignItemList":
        """Return a new list holding the ``k`` items with the largest log_prob.

        Ties keep their original relative order (``sorted`` is stable).
        """
        ranked = sorted(self.data, key=lambda item: item.log_prob, reverse=True)
        return AlignItemList(ranked[:k])
def force_alignment(
    model: Transducer,
    encoder_out: torch.Tensor,
    ys: List[int],
    beam_size: int = 4,
) -> List[int]:
    """Force-align a token sequence to encoder frames using beam search.

    Every frame emits exactly one symbol: either ``blank_id`` (a horizontal
    move in the lattice) or the next token of ``ys`` (a diagonal move).
    Only moves that still allow all of ``ys`` to be emitted before the last
    frame are explored. After each frame the ``beam_size`` best hypotheses
    are kept.

    Args:
      model:
        The transducer model. It is expected to provide ``model.decoder``
        (with ``blank_id`` and ``context_size`` attributes),
        ``model.joiner`` and ``model.device``.
      encoder_out:
        A tensor of shape (N, T, C). Support only for N==1 now.
      ys:
        A list of token IDs. We require that 0 < len(ys) <= T.
      beam_size:
        Size of the beam used in beam search. Must be positive.
    Returns:
      Return a list of int such that
        - len(ans) == T
        - After removing blanks from ans, we have ans == ys.
    """
    assert encoder_out.ndim == 3, encoder_out.ndim
    assert encoder_out.size(0) == 1, encoder_out.size(0)
    assert 0 < len(ys) <= encoder_out.size(1), (len(ys), encoder_out.size(1))
    # An empty beam would leave no hypothesis to return.
    assert beam_size > 0, beam_size

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device

    T = encoder_out.size(1)
    U = len(ys)

    encoder_out_len = torch.tensor([1])
    decoder_out_len = encoder_out_len

    start = AlignItem(log_prob=0.0, ys=[], pos_u=0)
    B = AlignItemList([start])

    for t in range(T):
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # current_encoder_out is of shape (1, 1, encoder_out_dim)
        # fmt: on

        # A holds the hypotheses covering frames [0, t); B collects their
        # one-frame extensions. This only rebinds names, nothing is copied.
        # No reachability filtering is needed: the transition guards below
        # only ever create hypotheses that can still reach (T, U).
        A = B
        B = AlignItemList()

        decoder_input = A.get_decoder_input(
            ys=ys, context_size=context_size, blank_id=blank_id
        )
        decoder_input = torch.tensor(decoder_input, device=device)
        # decoder_input is of shape (num_active_items, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False)
        # decoder_output is of shape (num_active_items, 1, decoder_output_dim)

        current_encoder_out = current_encoder_out.expand(
            decoder_out.size(0), 1, -1
        )

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            encoder_out_len.expand(decoder_out.size(0)),
            decoder_out_len.expand(decoder_out.size(0)),
        )
        # logits is of shape (num_active_items, vocab_size)

        log_probs = logits.log_softmax(dim=-1).tolist()

        for i, item in enumerate(A):
            # Emitting a blank is allowed only if the U - pos_u tokens still
            # to be emitted fit into the T - 1 - t frames that remain.
            if (T - 1 - t) >= (U - item.pos_u):
                # horizontal transition
                new_item = AlignItem(
                    log_prob=item.log_prob + log_probs[i][blank_id],
                    ys=item.ys + [blank_id],
                    pos_u=item.pos_u,
                )
                B.append(new_item)

            if item.pos_u < U:
                # diagonal transition: emit the next reference token
                u = ys[item.pos_u]
                new_item = AlignItem(
                    log_prob=item.log_prob + log_probs[i][u],
                    ys=item.ys + [u],
                    pos_u=item.pos_u + 1,
                )
                B.append(new_item)

        if len(B) > beam_size:
            B = B.topk(beam_size)

    # At the last frame the guards above only admit hypotheses with
    # pos_u == U, so every survivor has emitted all of ys; return the best.
    return B.topk(1)[0].ys