diff --git a/icefall/decode.py b/icefall/decode.py
index 17010ec37..736e5200b 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -20,11 +20,7 @@ from typing import Dict, List, Optional, Union
 import k2
 import torch
 
-from icefall.lm.rescore import (
-    compute_alignment,
-    make_hyp_to_ref_map,
-    prepare_conformer_lm_inputs,
-)
+from icefall.lm.rescore import compute_alignment, prepare_conformer_lm_inputs
 from icefall.utils import get_texts
 
 
@@ -989,8 +985,6 @@
     tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
     tokens = tokens.remove_values_leq(0)
 
-    device = model.device
-
     # import pdb
     #
     # pdb.set_trace()
@@ -998,70 +992,51 @@
     path_per_utt = (
         nbest.shape.row_splits(1)[1:] - nbest.shape.row_splits(1)[:-1]
    )
     logging.info(f"path per utt: {path_per_utt}")
-    if 1 not in path_per_utt:
-        device2 = masked_lm_model.device
-        alignment = compute_alignment(
-            tokens.to(device2), nbest.shape.to(device2)
+    device2 = masked_lm_model.device
+
+    alignment = compute_alignment(tokens.to(device2), nbest.shape.to(device2))
+    tgt_ll_list = []
+    for label_name in ["ref_labels", "hyp_labels"]:
+        (
+            masked_src_symbols,
+            src_symbols,
+            tgt_symbols,
+            src_key_padding_mask,
+            tgt_weights,
+        ) = prepare_conformer_lm_inputs(
+            alignment,
+            bos_id=sos_id,
+            eos_id=eos_id,
+            blank_id=blank_id,
+            src_label_name=label_name,
+            unmasked_weight=0.0,
         )
-        tgt_ll_list = []
-        for label_name in ["ref_labels", "hyp_labels"]:
-            (
-                masked_src_symbols,
-                src_symbols,
-                tgt_symbols,
-                src_key_padding_mask,
-                tgt_weights,
-            ) = prepare_conformer_lm_inputs(
-                alignment,
-                bos_id=sos_id,
-                eos_id=eos_id,
-                blank_id=blank_id,
-                src_label_name=label_name,
-                unmasked_weight=0.0,
-            )
-            masked_src_symbols = masked_src_symbols.to(torch.int64)
-            src_symbols = src_symbols.to(torch.int64)
-            tgt_symbols = tgt_symbols.to(torch.int64)
+        masked_src_symbols = masked_src_symbols.to(torch.int64)
+        src_symbols = src_symbols.to(torch.int64)
+        tgt_symbols = tgt_symbols.to(torch.int64)
 
-            masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
-                masked_src_symbols, src_key_padding_mask
-            )
-
-            tgt_nll = masked_lm_model.decoder_nll(
-                masked_lm_memory,
-                masked_lm_pos_emb,
-                src_symbols,
-                tgt_symbols,
-                src_key_padding_mask,
-            )
-
-            # nll means negative log-likelihood
-            # ll means log-likelihood
-            tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
-
-            tgt_ll_list.append(tgt_ll)
-
-        # tgt_ll = tgt_ll_list[1] - tgt_ll_list[0]  # wer: 2.61
-        tgt_ll = tgt_ll_list[0] - tgt_ll_list[1]
-
-        # TODO(fangjun): Add documentation about why we do the following
-        tgt_ll_shape_row_ids = make_hyp_to_ref_map(
-            nbest.shape.row_splits(1).to(device2)
+        masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
+            masked_src_symbols, src_key_padding_mask
         )
-        tgt_ll_shape = k2.ragged.create_ragged_shape2(
-            row_splits=None,
-            row_ids=tgt_ll_shape_row_ids,
-            cached_tot_size=tgt_ll_shape_row_ids.numel(),
-        )
-        ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
-        ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
-        masked_lm_scores = ragged_tgt_ll.max().to(device)
-    else:
-        logging.warning(f"Disable masked lm. path per utt is: {path_per_utt}")
-        masked_lm_scores = torch.zeros_like(am_scores.values)
+        tgt_nll = masked_lm_model.decoder_nll(
+            masked_lm_memory,
+            masked_lm_pos_emb,
+            src_symbols,
+            tgt_symbols,
+            src_key_padding_mask,
+        )
+
+        # nll means negative log-likelihood
+        # ll means log-likelihood
+        tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
+
+        tgt_ll_list.append(tgt_ll)
+
+    # hyp - ref
+    masked_lm_scores = tgt_ll_list[1] - tgt_ll_list[0]
 
     # TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
     token_ids = tokens.tolist()
diff --git a/icefall/lm/rescore.py b/icefall/lm/rescore.py
index 466b6fa6f..dbac1526c 100644
--- a/icefall/lm/rescore.py
+++ b/icefall/lm/rescore.py
@@ -17,16 +17,22 @@
 """
 This file contains rescoring code for NN LMs, e.g., conformer LM.
 
-Support an utterance has 3 paths:
-    (a, b, c)
+Suppose that an utterance has 5 paths:
+    (a, b, c, d, e)
+
 and we want to use a masked conformer LM to assign a likelihood to each path.
 
 The following shows the steps:
 
-(1) Select path pairs:
-    (a, b), (a, c)
-    (b, a), (b, c)
-    (c, a), (c, b)
+(1) Select "a" as the reference path. Note: For ease of implementation,
+    we always select the first path as the reference one.
+
+(2) We have the following path pairs:
+
+    (a, a), (a, b), (a, c), (a, d), (a, e)
+
+    Note: Even though we know the likelihood difference for the pair (a, a)
+    is always 0, we still list it here for ease of implementation.
 
-(2) For each pair, e.g., for the pair (a, b),
+(3) For each pair, e.g., for the pair (a, b),
 
@@ -39,20 +45,27 @@ The following shows the steps:
     (iv) Use "b" as "src" and its shifted version as "tgt". We can get
          another likelihood value, denoted as "ab_other"
 
-So for the path pair (a, b), (a, c), (b, a), (b, c), (c, a), and (c, b),
+So for the path pairs (a, a), (a, b), (a, c), (a, d), and (a, e),
 we can get the following log-likelihood values, viewed as two tensors:
 
-    self = [ab_self, ac_self, ba_self, bc_self, ca_self, cb_self]
+    self = [aa_self, ab_self, ac_self, ad_self, ae_self]
 
-    other = [ab_other, ac_other, ba_other, bc_other, ca_other, cb_other]
+    other = [aa_other, ab_other, ac_other, ad_other, ae_other]
 
-    Compute the difference the two tensors:
+    Compute the difference between the two tensors:
 
-        self - other = [ab_self - ab_other, ac_self - ac_other, ...]
+        other - self = [aa_other - aa_self, ab_other - ab_self,
+                        ac_other - ac_self, ... ]
 
-    The log-likelihood for path a is : max(ab_self - ab_other, ac_self - ac_other)
-    The log-likelihood for path b is : max(ba_self - ba_other, bc_self - bc_other)
-    The log-likelihood for path c is : max(ca_self - ca_other, cb_self - cb_other)
+    The log-likelihood for path a is: 0
+    The log-likelihood for path b is: ab_other - ab_self
+    The log-likelihood for path c is: ac_other - ac_self
+    The log-likelihood for path d is: ad_other - ad_self
+    The log-likelihood for path e is: ae_other - ae_self
+
+Note: "ab_other - ab_self" can be interpreted as
+
+    log P(b) - log P(a)
 """
 
 from typing import Tuple
@@ -181,14 +194,13 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
     return concat(ragged, eos_id, direction="right")
 
 
-def make_hyp_to_ref_map(row_splits: torch.Tensor):
+def make_hyp_to_ref_map(row_splits: torch.Tensor) -> torch.Tensor:
     """
-    TODO: Add documentation.
+    Map each hyp path to the first path (the reference) of its utterance.
 
     >>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
     >>> make_hyp_to_ref_map(row_splits)
-    tensor([0, 0, 1, 1, 2, 2, 3, 4], dtype=torch.int32)
-
+    tensor([0, 0, 0, 3, 3], dtype=torch.int32)
     """
     device = row_splits.device
     sizes = (row_splits[1:] - row_splits[:-1]).tolist()
@@ -198,78 +210,15 @@ def make_hyp_to_ref_map(row_splits: torch.Tensor):
     for size, offset in zip(sizes, offsets):
         # Explanation of the following operations
         # assume size is 3, offset is 2
-        # torch.arange() + offset is [2, 3, 4]
-        # expand() is [[2, 3, 4], [2, 3, 4]]
-        # t() is [[2, 2], [3, 3], [4, 4]]
-        # reshape() is [2, 2, 3, 3, 4, 4]
+        # torch.zeros() + offset is [2, 2, 2]
         map_tensor = (
-            (torch.arange(size, dtype=torch.int32, device=device) + offset)
-            .expand(size - 1, size)
-            .t()
-            .reshape(-1)
+            torch.zeros(size, dtype=torch.int32, device=device) + offset
         )
         map_tensor_list.append(map_tensor)
 
     return torch.cat(map_tensor_list)
 
 
-def make_repeat_map(row_splits: torch.Tensor):
-    """
-    TODO: Add documentation.
-
-    >>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
-    >>> make_repeat_map(row_splits)
-    tensor([1, 2, 0, 2, 0, 1, 4, 3], dtype=torch.int32)
-
-    """
-    device = row_splits.device
-    sizes = (row_splits[1:] - row_splits[:-1]).tolist()
-    offsets = row_splits[:-1]
-
-    map_tensor_list = []
-    for size, offset in zip(sizes, offsets):
-        # Explanation of the following operations
-        # assume size is 3, offset is 2
-        # torch.arange() + offset is [2, 3, 4]
-        # expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
-        # reshape() is [2, 3, 4, 2, 3, 4, 2, 3, 4]
-        map_tensor = (
-            (torch.arange(size, dtype=torch.int32, device=device) + offset)
-            .expand(size, size)
-            .reshape(-1)
-        )
-        diag_offset = torch.arange(size, device=device) * (size + 1)
-        # remove diagonal elements
-        map_tensor[diag_offset] = -1
-        map_tensor = map_tensor[map_tensor != -1]
-        # In the above example, map_tensor becomes
-        # [3, 4, 2, 4, 2, 3]
-        map_tensor_list.append(map_tensor)
-
-    return torch.cat(map_tensor_list)
-
-
-def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
-    """Repeat paths in an utterance.
-
-    For instance, if an utterance contains 3 paths: [path1 path2 path3],
-    after repeating, this utterance will contain 6 paths:
-    [path2 path3] [path1 path3] [path1 path2]
-
-    >>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ])
-    >>> tokens.to_str_simple()
-    'RaggedTensor([[[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]]], dtype=torch.int32)'
-    >>> make_repeat(tokens).to_str_simple()
-    'RaggedTensor([[[4, 5], [9], [1, 2, 3], [9], [1, 2, 3], [4, 5]], [[10, 1], [5, 8]]], dtype=torch.int32)'  # noqa
-
-    TODO: Add documentation.
-
-    """
-    assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}"
-    indexes = make_repeat_map(tokens.shape.row_splits(1))
-    return tokens.index(axis=1, indexes=indexes)[0]
-
-
 def compute_alignment(
     tokens: k2.RaggedTensor,
     shape: k2.RaggedShape,
@@ -285,15 +234,13 @@
     """
     assert tokens.tot_size(0) == shape.tot_size(1)
     device = tokens.device
-    utt_path_shape = shape.compose(tokens.shape)
-    utt_path_token = k2.RaggedTensor(utt_path_shape, tokens.values)
-    utt_path_token_repeated = make_repeat(utt_path_token)
-    path_token_repeated = utt_path_token_repeated.remove_axis(0)
+
+    hyps = k2.levenshtein_graph(tokens, device=device)
     refs = k2.levenshtein_graph(tokens, device=device)
-    hyps = k2.levenshtein_graph(path_token_repeated, device=device)
 
-    hyp_to_ref_map = make_hyp_to_ref_map(utt_path_shape.row_splits(1))
+    hyp_to_ref_map = make_hyp_to_ref_map(shape.row_splits(1))
+
     alignment = k2.levenshtein_alignment(
         refs=refs, hyps=hyps, hyp_to_ref_map=hyp_to_ref_map
     )
diff --git a/test/lm/test_rescore.py b/test/lm/test_rescore.py
index ac77ecb14..5855f2987 100755
--- a/test/lm/test_rescore.py
+++ b/test/lm/test_rescore.py
@@ -22,8 +22,6 @@ from icefall.lm.rescore import (
     add_eos,
     compute_alignment,
     make_hyp_to_ref_map,
-    make_repeat,
-    make_repeat_map,
     prepare_conformer_lm_inputs,
 )
 
@@ -67,42 +65,9 @@ def test_pad():
 def test_make_hyp_to_ref_map():
     a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
     row_splits = a.shape.row_splits(1)
-    repeat_map = make_hyp_to_ref_map(row_splits)
-    # fmt: off
-    expected = torch.tensor([0, 0, 1, 1, 2, 2, 3,
-                             3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]).to(repeat_map)  # noqa
-    # fmt: on
-    assert torch.all(torch.eq(repeat_map, expected))
-
-
-def test_make_repeat_map():
-    a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
-    row_splits = a.shape.row_splits(1)
-    repeat_map = make_repeat_map(row_splits)
-    # fmt: off
-    expected = torch.tensor([1, 2, 0, 2, 0, 1,
-                             4, 5, 6, 3, 5, 6, 3, 4, 6,  # noqa
-                             3, 4, 5]).to(repeat_map)  # noqa
-    # fmt: on
-    assert torch.all(torch.eq(repeat_map, expected))
-
-
-def test_make_repeat():
-    # fmt: off
-    a = k2.RaggedTensor([
-        [[1, 3, 5], [2, 6]],
-        [[1, 2, 3, 4], [2], [], [9, 10, 11]],
-    ])
-    b = make_repeat(a)
-    expected = k2.RaggedTensor([
-        [[2, 6], [1, 3, 5]],
-        [ [2], [], [9, 10, 11],  # noqa
-          [1, 2, 3, 4], [], [9, 10, 11],  # noqa
-          [1, 2, 3, 4], [2], [9, 10, 11],  # noqa
-          [1, 2, 3, 4], [2], [], ],  # noqa
-    ])
-    # fmt: on
-    assert str(b) == str(expected)
+    hyp_to_ref_map = make_hyp_to_ref_map(row_splits)
+    expected = torch.tensor([0, 0, 0, 3, 3, 3, 3]).to(hyp_to_ref_map)
+    assert torch.all(torch.eq(hyp_to_ref_map, expected))
 
 
 def test_compute_alignment():
@@ -140,9 +105,7 @@ def main():
     test_add_bos()
     test_add_eos()
     test_pad()
-    test_make_repeat_map()
     test_make_hyp_to_ref_map()
-    test_make_repeat()
     test_compute_alignment()
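
Reviewer note: below is a minimal, self-contained sketch (plain PyTorch; k2 and
the masked LM are not required) of the pairing scheme this patch introduces.
The helper mirrors the rewritten `make_hyp_to_ref_map`, and the likelihood
tensors hold made-up values purely for illustration: every path of an utterance
is paired with the first path (the reference), and the final score of a path is
"hyp - ref", i.e., log P(path) - log P(reference), so the reference path itself
always scores 0.

    import torch


    def make_hyp_to_ref_map(row_splits: torch.Tensor) -> torch.Tensor:
        # Map every path of an utterance to the first path of that
        # utterance, mirroring the version added in icefall/lm/rescore.py.
        sizes = (row_splits[1:] - row_splits[:-1]).tolist()
        offsets = row_splits[:-1]
        return torch.cat(
            [
                torch.zeros(size, dtype=torch.int32) + offset
                for size, offset in zip(sizes, offsets)
            ]
        )


    # Two utterances: the first has 3 paths, the second has 2.
    row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
    print(make_hyp_to_ref_map(row_splits))
    # tensor([0, 0, 0, 3, 3], dtype=torch.int32)

    # Hypothetical per-pair log-likelihoods for the 5 (ref, hyp) pairs;
    # "ref" plays the role of tgt_ll_list[0] and "hyp" of tgt_ll_list[1]
    # in rescore_with_conformer_lm.
    ref_ll = torch.tensor([-1.0, -1.0, -1.0, -2.0, -2.0])
    hyp_ll = torch.tensor([-1.0, -0.5, -3.0, -2.0, -1.5])

    # masked_lm_scores = hyp - ref; each reference path gets exactly 0.
    print(hyp_ll - ref_ll)
    # tensor([ 0.0000,  0.5000, -2.0000,  0.0000,  0.5000])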