From 3441634f34691721302e37337edc4c5acdd873c6 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 1 Nov 2021 21:34:22 +0800
Subject: [PATCH] Finish preparing the inputs for the conformer LM from an
 Nbest object.

---
 icefall/lm/rescore.py   | 138 ++++++++++++++++++++++++++++++++++++++--
 test/lm/test_rescore.py |  57 +++++++++++++----
 2 files changed, 178 insertions(+), 17 deletions(-)

diff --git a/icefall/lm/rescore.py b/icefall/lm/rescore.py
index fa15cc621..be2b3c929 100644
--- a/icefall/lm/rescore.py
+++ b/icefall/lm/rescore.py
@@ -42,6 +42,28 @@ import torch
 from icefall.decode import Nbest
 
 
+def make_key_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """
+    Convert sequence lengths to a key padding mask; padded positions are True.
+
+    >>> make_key_padding_mask(torch.tensor([3, 1, 4]))
+    tensor([[False, False, False,  True],
+            [False,  True,  True,  True],
+            [False, False, False, False]])
+    """
+    assert lengths.dim() == 1
+
+    bs = lengths.numel()
+    max_len = lengths.max().item()
+    device = lengths.device
+    seq_range = torch.arange(0, max_len, device=device)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
+
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
+
+
 def concat(
     ragged: k2.RaggedTensor, value: int, direction: str
 ) -> k2.RaggedTensor:
@@ -77,13 +99,12 @@ def concat(
     assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
 
     pad_values = torch.full(
-        size=(ragged.tot_size(0),),
+        size=(ragged.tot_size(0), 1),
         fill_value=value,
         device=device,
         dtype=dtype,
     )
-    pad_shape = k2.ragged.regular_ragged_shape(ragged.tot_size(0), 1).to(device)
-    pad = k2.RaggedTensor(pad_shape, pad_values)
+    pad = k2.RaggedTensor(pad_values)
 
     if direction == "left":
         ans = k2.ragged.cat([pad, ragged], axis=1)
@@ -233,10 +254,10 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
     return k2.RaggedTensor(repeated).to(tokens.device)
 
 
-def compute_alignments(
+def compute_alignment(
     tokens: k2.RaggedTensor,
     shape: k2.RaggedShape,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> k2.Fsa:
     """
     TODO: Add documentation.
 
@@ -263,9 +284,90 @@
     return alignment
 
 
+def prepare_conformer_lm_inputs(
+    alignment: k2.Fsa,
+    bos_id: int,
+    eos_id: int,
+    blank_id: int,
+    unmasked_weight: float = 0.25,
+) -> Tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
+]:
+    """
+    Prepare the inputs for a conformer LM from an alignment FSA.
+
+    Args:
+      alignment:
+        It is computed by :func:`compute_alignment`.
+    """
+    # alignment.arcs.shape has axes [fsa][state][arc];
+    # we remove axis 1, i.e., the state axis, here
+    labels_shape = alignment.arcs.shape().remove_axis(1)
+
+    masked_src = k2.RaggedTensor(labels_shape, alignment.labels.contiguous())
+    masked_src = masked_src.remove_values_eq(-1)
+    bos_masked_src = add_bos(masked_src, bos_id=bos_id)
+    bos_masked_src_eos = add_eos(bos_masked_src, eos_id=eos_id)
+    bos_masked_src_eos_pad = bos_masked_src_eos.pad(
+        mode="constant", padding_value=blank_id
+    )
+
+    src = k2.RaggedTensor(labels_shape, alignment.hyp_labels)
+    src = src.remove_values_eq(-1)
+    bos_src = add_bos(src, bos_id=bos_id)
+    bos_src_eos = add_eos(bos_src, eos_id=eos_id)
+    bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id)
+
+    tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels)
+    # TODO: Do we need to remove 0s from tgt?
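+    # In k2, arcs entering the final state have label -1, so the call
+    # below removes those final-arc markers from each sequence.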
+    tgt = tgt.remove_values_eq(-1)
+    tgt_eos = add_eos(tgt, eos_id=eos_id)
+
+    # append a blank so that tgt_eos has the same length as bos_src_eos,
+    # which starts with bos; we assume the blank id is 0
+    tgt_eos = add_eos(tgt_eos, eos_id=blank_id)
+
+    row_splits = tgt_eos.shape.row_splits(1)
+    lengths = row_splits[1:] - row_splits[:-1]
+    src_key_padding_mask = make_key_padding_mask(lengths)
+
+    tgt_eos_pad = tgt_eos.pad(mode="constant", padding_value=blank_id)
+
+    weight = torch.full(
+        (tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
+        fill_value=1,
+        dtype=torch.float32,
+    )
+
+    # find unmasked positions
+    unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1]
+    weight[unmasked_positions] = unmasked_weight
+
+    # set the weights for padding positions to 0
+    weight[src_key_padding_mask[:, 1:]] = 0
+    zeros = torch.zeros(weight.size(0), 1).to(weight)
+
+    weight = torch.cat((weight, zeros), dim=1)
+
+    # all other positions are assumed to be masked and
+    # have the default weight 1
+
+    return (
+        bos_masked_src_eos_pad,
+        bos_src_eos_pad,
+        tgt_eos_pad,
+        src_key_padding_mask,
+        weight,
+    )
+
+
 def conformer_lm_rescore(
     nbest: Nbest,
     model: torch.nn.Module,
+    bos_id: int,
+    eos_id: int,
+    blank_id: int,
+    unmasked_weight: float = 0.25,
     # TODO: add other arguments if needed
 ) -> k2.RaggedTensor:
     """Rescore an Nbest object with a conformer_lm model.
@@ -281,4 +383,28 @@
       contained in the nbest. Its shape equals to `nbest.shape`.
     """
     assert hasattr(nbest.fsa, "tokens")
-    # TODO:
+    utt_path_shape = nbest.shape
+    # nbest.fsa.arcs.shape() has axes [path][state][arc];
+    # we remove the state axis here
+    path_token_shape = nbest.fsa.arcs.shape().remove_axis(1)
+
+    path_token = k2.RaggedTensor(path_token_shape, nbest.fsa.tokens)
+    path_token = path_token.remove_values_leq(0)
+
+    alignment = compute_alignment(path_token, utt_path_shape)
+    (
+        masked_src,
+        src,
+        tgt,
+        src_key_padding_mask,
+        weight,
+    ) = prepare_conformer_lm_inputs(
+        alignment,
+        bos_id=bos_id,
+        eos_id=eos_id,
+        blank_id=blank_id,
+        unmasked_weight=unmasked_weight,
+    )
+    # TODO: pass masked_src, src, tgt, src_key_padding_mask, and weight
+    # to the given model; for now, return them directly for testing
+    return masked_src, src, tgt, src_key_padding_mask, weight
diff --git a/test/lm/test_rescore.py b/test/lm/test_rescore.py
index fd7457a7e..bade45b5d 100755
--- a/test/lm/test_rescore.py
+++ b/test/lm/test_rescore.py
@@ -17,13 +17,16 @@
 
 import k2
 import torch
+from icefall.decode import Nbest
 from icefall.lm.rescore import (
     add_bos,
     add_eos,
-    make_repeat_map,
+    compute_alignment,
+    conformer_lm_rescore,
     make_hyp_to_ref_map,
     make_repeat,
-    compute_alignments,
+    make_repeat_map,
+    prepare_conformer_lm_inputs,
 )
 
 
@@ -103,20 +106,51 @@ def test_make_repeat():
     assert str(b) == str(expected)
 
 
-def test_compute_alignments():
+def test_compute_alignment():
     # fmt: off
     tokens = k2.RaggedTensor([
         # utt 0
-        [1, 3], [8], [2],
+        [1, 3, 5, 8], [1, 5, 8], [2, 8, 3, 2],
         # utt 1
-        [1, 5], [9],
+        [2, 3], [2],
     ])
-    shape = k2.RaggedShape('[[x x x] [x x]]')
     # fmt: on
-    alignment = compute_alignments(tokens, shape)
-    print("maksed_src:", alignment.labels)
-    print("src:", alignment.hyp_labels)
-    print("tgt:", alignment.ref_labels)
+    shape = k2.RaggedShape("[[x x x] [x x]]")
+    alignment = compute_alignment(tokens, shape)
+    (
+        masked_src,
+        src,
+        tgt,
+        src_key_padding_mask,
+        weight,
+    ) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0)
+
+    # print("masked src", masked_src)
+    # print("src", src)
+    # print("tgt", tgt)
+    # print("src_key_padding_mask", src_key_padding_mask)
+    # print("weight", weight)
+
+
+def test_conformer_lm_rescore():
+    path00 = k2.linear_fsa([1, 2, 0, 3, 0, 5])
+    path01 = k2.linear_fsa([1, 0, 5, 0])
+    path10 = k2.linear_fsa([9, 8, 0, 3, 0, 2])
+    path11 = k2.linear_fsa([9, 8, 0, 0, 3, 2])
+    path12 = k2.linear_fsa([9, 0, 8, 4, 0, 2, 3])
+
+    fsa = k2.Fsa.from_fsas([path00, path01, path10, path11, path12])
+    fsa.tokens = fsa.labels.clone()
+    shape = k2.RaggedShape("[[x x] [x x x]]")
+    nbest = Nbest(fsa, shape)
+    masked_src, src, tgt, src_key_padding_mask, weight = conformer_lm_rescore(
+        nbest, model=None, bos_id=10, eos_id=20, blank_id=0
+    )
+    print("masked src", masked_src)
+    print("src", src)
+    print("tgt", tgt)
+    print("src_key_padding_mask", src_key_padding_mask)
+    print("weight", weight)
 
 
 def main():
@@ -126,7 +160,8 @@ def main():
     test_make_repeat_map()
     test_make_hyp_to_ref_map()
     test_make_repeat()
-    test_compute_alignments()
+    test_compute_alignment()
+    test_conformer_lm_rescore()
 
 
 if __name__ == "__main__":
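
Note on the returned `weight` tensor: masked positions keep the default
weight 1, positions where `src` agrees with the shifted `tgt` (unmasked
positions) get `unmasked_weight`, and padding gets 0. Below is a minimal,
self-contained sketch of how such a weight tensor could scale a per-position
cross-entropy loss. The shapes, the `logits` tensor, and the example values
are illustrative assumptions, not the actual conformer_lm model interface.

    import torch
    import torch.nn.functional as F

    # Hypothetical batch: 2 padded sequences of length 5 over a
    # 30-token vocabulary (values are made up for illustration).
    vocab_size = 30
    logits = torch.randn(2, 5, vocab_size)      # stand-in for model output
    tgt = torch.randint(0, vocab_size, (2, 5))  # stand-in for tgt_eos_pad
    weight = torch.tensor([[1.0, 0.25, 1.0, 1.0, 0.0],
                           [1.0, 1.0, 0.25, 0.0, 0.0]])

    # Per-position negative log-likelihood, weighted so that unmasked
    # positions contribute unmasked_weight and padding contributes nothing.
    nll = F.cross_entropy(
        logits.reshape(-1, vocab_size), tgt.reshape(-1), reduction="none"
    ).reshape(tgt.shape)
    loss = (nll * weight).sum() / weight.sum()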