diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 860b31b08..1ffef3bff 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -498,6 +498,8 @@ def decode_dataset( results = defaultdict(list) for batch_idx, batch in enumerate(dl): + # if batch_idx > 10: + # break texts = batch["supervisions"]["text"] hyps_dict = decode_one_batch( @@ -539,7 +541,7 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): - if params.method == "attention-decoder": + if params.method in ("attention-decoder", "conformer-lm"): # Set it to False since there are too many logs. enable_log = False else: @@ -591,7 +593,7 @@ def main(): params = get_params() params.update(vars(args)) - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + setup_logger(f"{params.exp_dir}/log-{params.method}-2/log-decode") logging.info("Decoding started") logging.info(params) @@ -714,9 +716,15 @@ def main(): model.eval() num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + model.device = device if params.method == "conformer-lm": logging.info("Loading conformer lm model") + assert torch.cuda.device_count() > 1, f"{torch.cuda.device_count()}" + + # We use a second GPU for masked LM model as it causes OOM + # with 1 GPU + device2 = torch.device("cuda", 1) # Note: If the parameters does not match # the one used to save the checkpoint, it will # throw while calling `load_state_dict`. @@ -740,10 +748,12 @@ def main(): f"{params.conformer_lm_exp_dir}/epoch-{i}.pt" ) logging.info(f"averaging {filenames}") - masked_lm_model.to(device) + masked_lm_model.to(device2) masked_lm_model.load_state_dict( - average_checkpoints(filenames, device=device) + average_checkpoints(filenames, device=device2) ) + masked_lm_model.to(device2) + masked_lm_model.device = device2 else: masked_lm_model = None @@ -756,6 +766,8 @@ def main(): # test_sets = ["test-clean", "test-other"] for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + # if test_set == "test-clean": + # continue results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/icefall/decode.py b/icefall/decode.py index bc59a96d6..17010ec37 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -870,6 +870,7 @@ def rescore_with_attention_decoder( ngram_lm_scale_list = [0.01, 0.05, 0.08] ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] else: ngram_lm_scale_list = [ngram_lm_scale] @@ -877,6 +878,7 @@ def rescore_with_attention_decoder( attention_scale_list = [0.01, 0.05, 0.08] attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] else: attention_scale_list = [attention_scale] @@ -987,55 +989,79 @@ def rescore_with_conformer_lm( tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) tokens = tokens.remove_values_leq(0) - alignment = compute_alignment(tokens, nbest.shape) - ( - masked_src_symbols, - src_symbols, - tgt_symbols, - src_key_padding_mask, - tgt_weights, - ) = prepare_conformer_lm_inputs( - alignment, - bos_id=sos_id, - eos_id=eos_id, - blank_id=blank_id, - unmasked_weight=0.0, + device = model.device + + # import pdb + # + # pdb.set_trace() + path_per_utt = ( + nbest.shape.row_splits(1)[1:] - nbest.shape.row_splits(1)[:-1] ) + logging.info(f"path per utt: {path_per_utt}") + if 1 not in path_per_utt: + device2 = masked_lm_model.device - masked_src_symbols = masked_src_symbols.to(torch.int64) - src_symbols = src_symbols.to(torch.int64) - tgt_symbols = tgt_symbols.to(torch.int64) + alignment = compute_alignment( + tokens.to(device2), nbest.shape.to(device2) + ) + tgt_ll_list = [] + for label_name in ["ref_labels", "hyp_labels"]: + ( + masked_src_symbols, + src_symbols, + tgt_symbols, + src_key_padding_mask, + tgt_weights, + ) = prepare_conformer_lm_inputs( + alignment, + bos_id=sos_id, + eos_id=eos_id, + blank_id=blank_id, + src_label_name=label_name, + unmasked_weight=0.0, + ) - masked_lm_memory, masked_lm_pos_emb = masked_lm_model( - masked_src_symbols, src_key_padding_mask - ) + masked_src_symbols = masked_src_symbols.to(torch.int64) + src_symbols = src_symbols.to(torch.int64) + tgt_symbols = tgt_symbols.to(torch.int64) - tgt_nll = masked_lm_model.decoder_nll( - masked_lm_memory, - masked_lm_pos_emb, - src_symbols, - tgt_symbols, - src_key_padding_mask, - ) + masked_lm_memory, masked_lm_pos_emb = masked_lm_model( + masked_src_symbols, src_key_padding_mask + ) - # nll means negative log-likelihood - # ll means log-likelihood - tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1) + tgt_nll = masked_lm_model.decoder_nll( + masked_lm_memory, + masked_lm_pos_emb, + src_symbols, + tgt_symbols, + src_key_padding_mask, + ) - # Note: log-likelihood for those pairs that have identical src/tgt are 0 - # since their tgt_weights is 0 + # nll means negative log-likelihood + # ll means log-likelihood + tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1) - # TODO(fangjun): Add documentation about why we do the following - tgt_ll_shape_row_ids = make_hyp_to_ref_map(nbest.shape.row_splits(1)) - tgt_ll_shape = k2.ragged.create_ragged_shape2( - row_splits=None, - row_ids=tgt_ll_shape_row_ids, - cached_tot_size=tgt_ll_shape_row_ids.numel(), - ) - ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll) + tgt_ll_list.append(tgt_ll) - ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0) - masked_lm_scores = ragged_tgt_ll.max() + # tgt_ll = tgt_ll_list[1] - tgt_ll_list[0] # wer: 2.61 + tgt_ll = tgt_ll_list[0] - tgt_ll_list[1] + + # TODO(fangjun): Add documentation about why we do the following + tgt_ll_shape_row_ids = make_hyp_to_ref_map( + nbest.shape.row_splits(1).to(device2) + ) + tgt_ll_shape = k2.ragged.create_ragged_shape2( + row_splits=None, + row_ids=tgt_ll_shape_row_ids, + cached_tot_size=tgt_ll_shape_row_ids.numel(), + ) + ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll) + + ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0) + masked_lm_scores = ragged_tgt_ll.max().to(device) + else: + logging.warning(f"Disable masked lm. path per utt is: {path_per_utt}") + masked_lm_scores = torch.zeros_like(am_scores.values) # TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly. token_ids = tokens.tolist() @@ -1056,6 +1082,7 @@ def rescore_with_conformer_lm( ngram_lm_scale_list = [0.01, 0.05, 0.08] ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] else: ngram_lm_scale_list = [ngram_lm_scale] @@ -1063,6 +1090,7 @@ def rescore_with_conformer_lm( attention_scale_list = [0.01, 0.05, 0.08] attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] else: attention_scale_list = [attention_scale] @@ -1070,6 +1098,7 @@ def rescore_with_conformer_lm( masked_lm_scale_list = [0.01, 0.05, 0.08] masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] + masked_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0] else: masked_lm_scale_list = [masked_lm_scale] diff --git a/icefall/lm/rescore.py b/icefall/lm/rescore.py index e9984f5f6..2de9f735a 100644 --- a/icefall/lm/rescore.py +++ b/icefall/lm/rescore.py @@ -168,7 +168,7 @@ def make_hyp_to_ref_map(row_splits: torch.Tensor): >>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32) >>> make_hyp_to_ref_map(row_splits) - tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4], dtype=torch.int32) + tensor([0, 0, 1, 1, 2, 2, 3, 4], dtype=torch.int32) """ device = row_splits.device @@ -180,12 +180,12 @@ def make_hyp_to_ref_map(row_splits: torch.Tensor): # Explanation of the following operations # assume size is 3, offset is 2 # torch.arange() + offset is [2, 3, 4] - # expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]] - # t() is [[2, 2, 2], [3, 3, 3], [4, 4, 4]] - # reshape() is [2, 2, 2, 3, 3, 3, 4, 4, 4] + # expand() is [[2, 3, 4], [2, 3, 4]] + # t() is [[2, 2], [3, 3], [4, 4]] + # reshape() is [2, 2, 3, 3, 4, 4] map_tensor = ( (torch.arange(size, dtype=torch.int32, device=device) + offset) - .expand(size, size) + .expand(size - 1, size) .t() .reshape(-1) ) @@ -219,6 +219,12 @@ def make_repeat_map(row_splits: torch.Tensor): .expand(size, size) .reshape(-1) ) + diag_offset = torch.arange(size, device=device) * (size + 1) + # remove diagonal elements + map_tensor[diag_offset] = -1 + map_tensor = map_tensor[map_tensor != -1] + # In the above example, map_tensor becomes + # [3, 4, 2, 4, 2, 3] map_tensor_list.append(map_tensor) return torch.cat(map_tensor_list) @@ -233,25 +239,17 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor: [path1 path2 path3] [path1 path2 path3] [path1 path2 path3] >>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ]) - >>> tokens - [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ] - >>> make_repeat(tokens) - [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ] # noqa + >>> tokens.to_str_simple() + 'RaggedTensor([[[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]]], dtype=torch.int32)' + >>> make_repeat(tokens).to_str_simple() + 'RaggedTensor([[[4, 5], [9], [1, 2, 3], [9], [1, 2, 3], [4, 5]], [[10, 1], [5, 8]]], dtype=torch.int32)' # noqa TODO: Add documentation. """ assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}" - if True: - indexes = make_repeat_map(tokens.shape.row_splits(1)) - return tokens.index(axis=1, indexes=indexes)[0] - else: - # This branch produces the same result as the above branch. - # It's more readable. Will remove it later. - repeated = [] - for p in tokens.tolist(): - repeated.append(p * len(p)) - return k2.RaggedTensor(repeated).to(tokens.device) + indexes = make_repeat_map(tokens.shape.row_splits(1)) + return tokens.index(axis=1, indexes=indexes)[0] def compute_alignment( @@ -289,7 +287,8 @@ def prepare_conformer_lm_inputs( bos_id: int, eos_id: int, blank_id: int, - unmasked_weight: float = 0.25, + src_label_name: str, + unmasked_weight: float = 0.0, ) -> Tuple[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor ]: @@ -299,7 +298,18 @@ def prepare_conformer_lm_inputs( Args: alignments: It is computed by :func:`compute_alignment` + bos_id: + ID of the bos symbol. + eos_id: + ID of the eos symbol. + blank_id: + ID of the blank symbol. + src_label_name: + The name of the attribute from `alignment` that will be used for `src`. + `tgt` is a shift version of `src`. Valid values are: "ref_labels" + and "hyp_labels". """ + assert src_label_name in ("ref_labels", "hyp_labels") device = alignment.device # alignment.arcs.shape has axes [fsa][state][arc] # we remove axis 1, i.e., state, here @@ -313,13 +323,13 @@ def prepare_conformer_lm_inputs( mode="constant", padding_value=blank_id ) - src = k2.RaggedTensor(labels_shape, alignment.hyp_labels) + src = k2.RaggedTensor(labels_shape, getattr(alignment, src_label_name)) src = src.remove_values_eq(-1) bos_src = add_bos(src, bos_id=bos_id) bos_src_eos = add_eos(bos_src, eos_id=eos_id) bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id) - tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels) + tgt = k2.RaggedTensor(labels_shape, getattr(alignment, src_label_name)) # TODO: Do we need to remove 0s from tgt ? tgt = tgt.remove_values_eq(-1) tgt_eos = add_eos(tgt, eos_id=eos_id) @@ -342,7 +352,7 @@ def prepare_conformer_lm_inputs( ) # find unmasked positions - unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1] + unmasked_positions = bos_masked_src_eos_pad[:, 1:] != 0 weight[unmasked_positions] = unmasked_weight # set weights for paddings diff --git a/test/lm/test_rescore.py b/test/lm/test_rescore.py index 49016d1f5..ac77ecb14 100755 --- a/test/lm/test_rescore.py +++ b/test/lm/test_rescore.py @@ -69,8 +69,8 @@ def test_make_hyp_to_ref_map(): row_splits = a.shape.row_splits(1) repeat_map = make_hyp_to_ref_map(row_splits) # fmt: off - expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, - 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map) # noqa + expected = torch.tensor([0, 0, 1, 1, 2, 2, 3, + 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]).to(repeat_map) # noqa # fmt: on assert torch.all(torch.eq(repeat_map, expected)) @@ -80,9 +80,9 @@ def test_make_repeat_map(): row_splits = a.shape.row_splits(1) repeat_map = make_repeat_map(row_splits) # fmt: off - expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, - 3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, # noqa - 3, 4, 5, 6]).to(repeat_map) # noqa + expected = torch.tensor([1, 2, 0, 2, 0, 1, + 4, 5, 6, 3, 5, 6, 3, 4, 6, # noqa + 3, 4, 5]).to(repeat_map) # noqa # fmt: on assert torch.all(torch.eq(repeat_map, expected)) @@ -95,11 +95,11 @@ def test_make_repeat(): ]) b = make_repeat(a) expected = k2.RaggedTensor([ - [[1, 3, 5], [2, 6], [1, 3, 5], [2, 6]], - [[1, 2, 3, 4], [2], [], [9, 10, 11], - [1, 2, 3, 4], [2], [], [9, 10, 11], - [1, 2, 3, 4], [2], [], [9, 10, 11], - [1, 2, 3, 4], [2], [], [9, 10, 11]], + [[2, 6], [1, 3, 5]], + [ [2], [], [9, 10, 11], # noqa + [1, 2, 3, 4], [], [9, 10, 11], # noqa + [1, 2, 3, 4], [2], [9, 10, 11], # noqa + [1, 2, 3, 4], [2], [], ], # noqa ]) # fmt: on assert str(b) == str(expected) @@ -116,19 +116,24 @@ def test_compute_alignment(): # fmt: on shape = k2.RaggedShape("[[x x x] [x x]]") alignment = compute_alignment(tokens, shape) + print(alignment.ref_labels) + print(alignment.hyp_labels) + print(alignment.labels) ( masked_src, src, tgt, src_key_padding_mask, weight, - ) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0) + ) = prepare_conformer_lm_inputs( + alignment, bos_id=10, eos_id=20, blank_id=0, src_label_name="hyp_labels" + ) - # print("masked src", masked_src) - # print("src", src) - # print("tgt", tgt) - # print("src_key_padding_mask", src_key_padding_mask) - # print("weight", weight) + print("masked src", masked_src) + print("src", src) + print("tgt", tgt) + print("src_key_padding_mask", src_key_padding_mask) + print("weight", weight) def main():