Finish preparing the inputs for conformer lm from an nbest object.

Fangjun Kuang 2021-11-01 21:34:22 +08:00
parent 1b9e4f0fea
commit 3441634f34
2 changed files with 178 additions and 17 deletions

icefall/lm/rescore.py

@@ -42,6 +42,28 @@ import torch
 from icefall.decode import Nbest


+def make_key_padding_mask(lengths: torch.Tensor):
+    """
+    TODO: add documentation
+    >>> make_key_padding_mask(torch.tensor([3, 1, 4]))
+    tensor([[False, False, False,  True],
+            [False,  True,  True,  True],
+            [False, False, False, False]])
+    """
+    assert lengths.dim() == 1
+
+    bs = lengths.numel()
+    max_len = lengths.max().item()
+    device = lengths.device
+    seq_range = torch.arange(0, max_len, device=device)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
+
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    return mask
+
+
 def concat(
     ragged: k2.RaggedTensor, value: int, direction: str
 ) -> k2.RaggedTensor:
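
A minimal usage sketch (mine, not part of this commit): the mask built by
make_key_padding_mask is in the format PyTorch's Transformer layers expect
for src_key_padding_mask, where True marks positions to ignore. The encoder
sizes below are made up for illustration.

    import torch

    # assumes make_key_padding_mask from the hunk above is in scope
    lengths = torch.tensor([3, 1, 4])
    mask = make_key_padding_mask(lengths)  # (batch, max_len), True at pads

    layer = torch.nn.TransformerEncoderLayer(d_model=8, nhead=2)
    encoder = torch.nn.TransformerEncoder(layer, num_layers=1)

    x = torch.rand(4, 3, 8)  # (T, N, C) with T = max_len, N = batch size
    y = encoder(x, src_key_padding_mask=mask)  # padded frames are ignored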
@@ -77,13 +99,12 @@ def concat(
     assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"

     pad_values = torch.full(
-        size=(ragged.tot_size(0),),
+        size=(ragged.tot_size(0), 1),
         fill_value=value,
         device=device,
         dtype=dtype,
     )
-    pad_shape = k2.ragged.regular_ragged_shape(ragged.tot_size(0), 1).to(device)
-    pad = k2.RaggedTensor(pad_shape, pad_values)
+    pad = k2.RaggedTensor(pad_values)

     if direction == "left":
         ans = k2.ragged.cat([pad, ragged], axis=1)
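
Why this change works (my note, not part of this commit): when
k2.RaggedTensor is given a 2-D tensor of shape (N, 1), it infers a regular
ragged shape with one element per row, so the explicit
k2.ragged.regular_ragged_shape is no longer needed. A minimal sketch:

    import torch
    import k2

    pad_values = torch.full(size=(3, 1), fill_value=-1, dtype=torch.int32)
    pad = k2.RaggedTensor(pad_values)  # [[-1], [-1], [-1]]
    # one padding symbol per row, ready for k2.ragged.cat([pad, r], axis=1)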
@@ -233,10 +254,10 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
     return k2.RaggedTensor(repeated).to(tokens.device)


-def compute_alignments(
+def compute_alignment(
     tokens: k2.RaggedTensor,
     shape: k2.RaggedShape,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> k2.Fsa:
     """
     TODO: Add documentation.
@@ -263,9 +284,90 @@ def compute_alignments(
     return alignment


+def prepare_conformer_lm_inputs(
+    alignment: k2.Fsa,
+    bos_id: int,
+    eos_id: int,
+    blank_id: int,
+    unmasked_weight: float = 0.25,
+) -> Tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
+]:
+    """
+    TODO: add documentation.
+
+    Args:
+      alignment:
+        It is computed by :func:`compute_alignment`.
+    """
+    # alignment.arcs.shape() has axes [fsa][state][arc];
+    # we remove axis 1, i.e., state, here
+    labels_shape = alignment.arcs.shape().remove_axis(1)
+
+    masked_src = k2.RaggedTensor(labels_shape, alignment.labels.contiguous())
+    masked_src = masked_src.remove_values_eq(-1)
+    bos_masked_src = add_bos(masked_src, bos_id=bos_id)
+    bos_masked_src_eos = add_eos(bos_masked_src, eos_id=eos_id)
+    bos_masked_src_eos_pad = bos_masked_src_eos.pad(
+        mode="constant", padding_value=blank_id
+    )
+
+    src = k2.RaggedTensor(labels_shape, alignment.hyp_labels)
+    src = src.remove_values_eq(-1)
+    bos_src = add_bos(src, bos_id=bos_id)
+    bos_src_eos = add_eos(bos_src, eos_id=eos_id)
+    bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id)
+
+    tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels)
+    # TODO: Do we need to remove 0s from tgt?
+    tgt = tgt.remove_values_eq(-1)
+    tgt_eos = add_eos(tgt, eos_id=eos_id)
+
+    # Add a blank here since tgt_eos does not start with bos
+    # (we assume the blank id is 0).
+    tgt_eos = add_eos(tgt_eos, eos_id=blank_id)
+
+    row_splits = tgt_eos.shape.row_splits(1)
+    lengths = row_splits[1:] - row_splits[:-1]
+    src_key_padding_mask = make_key_padding_mask(lengths)
+
+    tgt_eos_pad = tgt_eos.pad(mode="constant", padding_value=blank_id)
+
+    weight = torch.full(
+        (tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
+        fill_value=1,
+        dtype=torch.float32,
+    )
+
+    # find unmasked positions
+    unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1]
+    weight[unmasked_positions] = unmasked_weight
+
+    # set weights for paddings
+    weight[src_key_padding_mask[:, 1:]] = 0
+    zeros = torch.zeros(weight.size(0), 1).to(weight)
+    weight = torch.cat((weight, zeros), dim=1)
+
+    # All other positions are assumed to be masked and
+    # keep the default weight 1.
+
+    return (
+        bos_masked_src_eos_pad,
+        bos_src_eos_pad,
+        tgt_eos_pad,
+        src_key_padding_mask,
+        weight,
+    )
+
+
 def conformer_lm_rescore(
     nbest: Nbest,
     model: torch.nn.Module,
+    bos_id: int,
+    eos_id: int,
+    blank_id: int,
+    unmasked_weight: float = 0.25,
     # TODO: add other arguments if needed
 ) -> k2.RaggedTensor:
     """Rescore an Nbest object with a conformer_lm model.
@@ -281,4 +383,28 @@ def conformer_lm_rescore(
         contained in the nbest. Its shape equals to `nbest.shape`.
     """
     assert hasattr(nbest.fsa, "tokens")
+    # TODO:
+    utt_path_shape = nbest.shape
+
+    # nbest.fsa.arcs.shape() has axes [path][state][arc];
+    # we remove the state axis here
+    path_token_shape = nbest.fsa.arcs.shape().remove_axis(1)
+    path_token = k2.RaggedTensor(path_token_shape, nbest.fsa.tokens)
+    path_token = path_token.remove_values_leq(0)
+    alignment = compute_alignment(path_token, utt_path_shape)
+    (
+        masked_src,
+        src,
+        tgt,
+        src_key_padding_mask,
+        weight,
+    ) = prepare_conformer_lm_inputs(
+        alignment,
+        bos_id=bos_id,
+        eos_id=eos_id,
+        blank_id=blank_id,
+        unmasked_weight=unmasked_weight,
+    )
+    return masked_src, src, tgt, src_key_padding_mask, weight
+    # TODO: pass masked_src, src, tgt, src_key_padding_mask, and weight
+    # to the given model
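
A toy sketch (mine, not part of this commit) of the shape manipulation used
in conformer_lm_rescore: dropping the state axis from the FSA's arc shape
groups the per-arc tokens by path, after which the non-positive labels
(final-arc -1 and epsilon 0) are stripped.

    import k2

    fsa = k2.Fsa.from_fsas([k2.linear_fsa([1, 2]), k2.linear_fsa([3])])
    shape = fsa.arcs.shape().remove_axis(1)  # [fsa][state][arc] -> [fsa][arc]
    tokens = k2.RaggedTensor(shape, fsa.labels.contiguous())
    # tokens is [[1, 2, -1], [3, -1]]; remove_values_leq(0) then strips
    # the final-arc -1 (and any epsilon 0), as in conformer_lm_rescore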

Test file for icefall/lm/rescore.py

@@ -17,13 +17,16 @@
 import k2
 import torch

+from icefall.decode import Nbest
 from icefall.lm.rescore import (
     add_bos,
     add_eos,
-    make_repeat_map,
+    compute_alignment,
+    conformer_lm_rescore,
     make_hyp_to_ref_map,
     make_repeat,
-    compute_alignments,
+    make_repeat_map,
+    prepare_conformer_lm_inputs,
 )
@@ -103,20 +106,51 @@ def test_make_repeat():
     assert str(b) == str(expected)


-def test_compute_alignments():
+def test_compute_alignment():
     # fmt: off
     tokens = k2.RaggedTensor([
         # utt 0
-        [1, 3], [8], [2],
+        [1, 3, 5, 8], [1, 5, 8], [2, 8, 3, 2],
         # utt 1
-        [1, 5], [9],
+        [2, 3], [2],
     ])
-    shape = k2.RaggedShape('[[x x x] [x x]]')
     # fmt: on
-    alignment = compute_alignments(tokens, shape)
-    print("maksed_src:", alignment.labels)
-    print("src:", alignment.hyp_labels)
-    print("tgt:", alignment.ref_labels)
+    shape = k2.RaggedShape("[[x x x] [x x]]")
+    alignment = compute_alignment(tokens, shape)
+    (
+        masked_src,
+        src,
+        tgt,
+        src_key_padding_mask,
+        weight,
+    ) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0)
+    # print("masked src", masked_src)
+    # print("src", src)
+    # print("tgt", tgt)
+    # print("src_key_padding_mask", src_key_padding_mask)
+    # print("weight", weight)
+
+
+def test_conformer_lm_rescore():
+    path00 = k2.linear_fsa([1, 2, 0, 3, 0, 5])
+    path01 = k2.linear_fsa([1, 0, 5, 0])
+    path10 = k2.linear_fsa([9, 8, 0, 3, 0, 2])
+    path11 = k2.linear_fsa([9, 8, 0, 0, 3, 2])
+    path12 = k2.linear_fsa([9, 0, 8, 4, 0, 2, 3])
+    fsa = k2.Fsa.from_fsas([path00, path01, path10, path11, path12])
+    fsa.tokens = fsa.labels.clone()
+    shape = k2.RaggedShape("[[x x] [x x x]]")
+    nbest = Nbest(fsa, shape)
+    masked_src, src, tgt, src_key_padding_mask, weight = conformer_lm_rescore(
+        nbest, model=None, bos_id=10, eos_id=20, blank_id=0
+    )
+    print("masked src", masked_src)
+    print("src", src)
+    print("tgt", tgt)
+    print("src_key_padding_mask", src_key_padding_mask)
+    print("weight", weight)


 def main():
@@ -126,7 +160,8 @@ def main():
     test_make_repeat_map()
     test_make_hyp_to_ref_map()
     test_make_repeat()
-    test_compute_alignments()
+    test_compute_alignment()
+    test_conformer_lm_rescore()


 if __name__ == "__main__":