Fixes after review.

This commit is contained in:
Fangjun Kuang 2021-11-15 15:44:25 +08:00
parent 57b9c8868b
commit 878fb40a12
3 changed files with 79 additions and 194 deletions

View File

@ -20,11 +20,7 @@ from typing import Dict, List, Optional, Union
import k2
import torch
from icefall.lm.rescore import (
compute_alignment,
make_hyp_to_ref_map,
prepare_conformer_lm_inputs,
)
from icefall.lm.rescore import compute_alignment, prepare_conformer_lm_inputs
from icefall.utils import get_texts
@ -989,8 +985,6 @@ def rescore_with_conformer_lm(
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
tokens = tokens.remove_values_leq(0)
device = model.device
# import pdb
#
# pdb.set_trace()
@ -998,70 +992,51 @@ def rescore_with_conformer_lm(
nbest.shape.row_splits(1)[1:] - nbest.shape.row_splits(1)[:-1]
)
logging.info(f"path per utt: {path_per_utt}")
if 1 not in path_per_utt:
device2 = masked_lm_model.device
alignment = compute_alignment(
tokens.to(device2), nbest.shape.to(device2)
device2 = masked_lm_model.device
alignment = compute_alignment(tokens.to(device2), nbest.shape.to(device2))
tgt_ll_list = []
for label_name in ["ref_labels", "hyp_labels"]:
(
masked_src_symbols,
src_symbols,
tgt_symbols,
src_key_padding_mask,
tgt_weights,
) = prepare_conformer_lm_inputs(
alignment,
bos_id=sos_id,
eos_id=eos_id,
blank_id=blank_id,
src_label_name=label_name,
unmasked_weight=0.0,
)
tgt_ll_list = []
for label_name in ["ref_labels", "hyp_labels"]:
(
masked_src_symbols,
src_symbols,
tgt_symbols,
src_key_padding_mask,
tgt_weights,
) = prepare_conformer_lm_inputs(
alignment,
bos_id=sos_id,
eos_id=eos_id,
blank_id=blank_id,
src_label_name=label_name,
unmasked_weight=0.0,
)
masked_src_symbols = masked_src_symbols.to(torch.int64)
src_symbols = src_symbols.to(torch.int64)
tgt_symbols = tgt_symbols.to(torch.int64)
masked_src_symbols = masked_src_symbols.to(torch.int64)
src_symbols = src_symbols.to(torch.int64)
tgt_symbols = tgt_symbols.to(torch.int64)
masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
masked_src_symbols, src_key_padding_mask
)
tgt_nll = masked_lm_model.decoder_nll(
masked_lm_memory,
masked_lm_pos_emb,
src_symbols,
tgt_symbols,
src_key_padding_mask,
)
# nll means negative log-likelihood
# ll means log-likelihood
tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
tgt_ll_list.append(tgt_ll)
# tgt_ll = tgt_ll_list[1] - tgt_ll_list[0] # wer: 2.61
tgt_ll = tgt_ll_list[0] - tgt_ll_list[1]
# TODO(fangjun): Add documentation about why we do the following
tgt_ll_shape_row_ids = make_hyp_to_ref_map(
nbest.shape.row_splits(1).to(device2)
masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
masked_src_symbols, src_key_padding_mask
)
tgt_ll_shape = k2.ragged.create_ragged_shape2(
row_splits=None,
row_ids=tgt_ll_shape_row_ids,
cached_tot_size=tgt_ll_shape_row_ids.numel(),
)
ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
masked_lm_scores = ragged_tgt_ll.max().to(device)
else:
logging.warning(f"Disable masked lm. path per utt is: {path_per_utt}")
masked_lm_scores = torch.zeros_like(am_scores.values)
tgt_nll = masked_lm_model.decoder_nll(
masked_lm_memory,
masked_lm_pos_emb,
src_symbols,
tgt_symbols,
src_key_padding_mask,
)
# nll means negative log-likelihood
# ll means log-likelihood
tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
tgt_ll_list.append(tgt_ll)
# hyp - ref
masked_lm_scores = tgt_ll_list[1] - tgt_ll_list[0]
# TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
token_ids = tokens.tolist()

View File

@ -17,16 +17,22 @@
"""
This file contains rescoring code for NN LMs, e.g., conformer LM.
Support an utterance has 3 paths:
(a, b, c)
Suppose that an utterance has 3 paths:
(a, b, c, d, e)
and we want to use a masked conformer LM to assign a likelihood to each path.
The following shows the steps:
(1) Select path pairs:
(a, b), (a, c)
(b, a), (b, c)
(c, a), (c, b)
(1) Select "a" as the reference path. Note: For ease of implementation,
we always select the first path as the reference one.
(2) We have the following path pairs:
(a, a) (a, b), (a, c), (a, d), (a, e)
Note: Even if we know the likelihood for the pair (a, a) is always 0,
we still list it here for ease of implementation.
(2) For each pair, e.g., for the pair (a, b),
@ -39,20 +45,27 @@ The following shows the steps:
(iv) Use "b" as "src" and its shifted version as "tgt".
We can get another likelihood value, denoted as "ab_other"
So for the path pair (a, b), (a, c), (b, a), (b, c), (c, a), and (c, b),
So for the path pair (a, a), (a, b), (a, c), (a, d), and (a, e),
we can get the following log-likelihood values, viewed as two tensors:
self = [ab_self, ac_self, ba_self, bc_self, ca_self, cb_self]
self = [aa_self, ab_self, ac_self, ad_self, ae_self]
other = [ab_other, ac_other, ba_other, bc_other, ca_other, cb_other]
other = [aa_other, ab_other, ac_other, ad_other, ae_other]
Compute the difference the two tensors:
Compute the difference between the two tensors:
self - other = [ab_self - ab_other, ac_self - ac_other, ...]
other - self = [aa_other - aa_self, ab_other - ab_self,
ac_other - ac_self, ... ]
The log-likelihood for path a is : max(ab_self - ab_other, ac_self - ac_other)
The log-likelihood for path b is : max(ba_self - ba_other, bc_self - bc_other)
The log-likelihood for path c is : max(ca_self - ca_other, cb_self - cb_other)
The log-likelihood for path a is : 0
The log-likelihood for path b is : ab_other - ab_self
The log-likelihood for path c is : ac_other - ac_self
The log-likelihood for path d is : ad_other - ad_self
The log-likelihood for path e is : ae_other - ae_self
Note: "ab_other - ab_self" can be interpreted as
log P(b) - log P(a)
"""
from typing import Tuple
@ -181,14 +194,13 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
return concat(ragged, eos_id, direction="right")
def make_hyp_to_ref_map(row_splits: torch.Tensor):
def make_hyp_to_ref_map(row_splits: torch.Tensor) -> torch.Tensor:
"""
TODO: Add documentation.
TODO: Add documentation
>>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
>>> make_hyp_to_ref_map(row_splits)
tensor([0, 0, 1, 1, 2, 2, 3, 4], dtype=torch.int32)
tensor([0, 0, 0, 3, 3], dtype=torch.int32)
"""
device = row_splits.device
sizes = (row_splits[1:] - row_splits[:-1]).tolist()
@ -198,78 +210,15 @@ def make_hyp_to_ref_map(row_splits: torch.Tensor):
for size, offset in zip(sizes, offsets):
# Explanation of the following operations
# assume size is 3, offset is 2
# torch.arange() + offset is [2, 3, 4]
# expand() is [[2, 3, 4], [2, 3, 4]]
# t() is [[2, 2], [3, 3], [4, 4]]
# reshape() is [2, 2, 3, 3, 4, 4]
# torch.zeros() + offset is [2, 2, 2]
map_tensor = (
(torch.arange(size, dtype=torch.int32, device=device) + offset)
.expand(size - 1, size)
.t()
.reshape(-1)
torch.zeros(size, dtype=torch.int32, device=device) + offset
)
map_tensor_list.append(map_tensor)
return torch.cat(map_tensor_list)
def make_repeat_map(row_splits: torch.Tensor):
"""
TODO: Add documentation.
>>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
>>> make_repeat_map(row_splits)
tensor([1, 2, 0, 2, 0, 1, 4, 3], dtype=torch.int32)
"""
device = row_splits.device
sizes = (row_splits[1:] - row_splits[:-1]).tolist()
offsets = row_splits[:-1]
map_tensor_list = []
for size, offset in zip(sizes, offsets):
# Explanation of the following operations
# assume size is 3, offset is 2
# torch.arange() + offset is [2, 3, 4]
# expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
# reshape() is [2, 3, 4, 2, 3, 4, 2, 3, 4]
map_tensor = (
(torch.arange(size, dtype=torch.int32, device=device) + offset)
.expand(size, size)
.reshape(-1)
)
diag_offset = torch.arange(size, device=device) * (size + 1)
# remove diagonal elements
map_tensor[diag_offset] = -1
map_tensor = map_tensor[map_tensor != -1]
# In the above example, map_tensor becomes
# [3, 4, 2, 4, 2, 3]
map_tensor_list.append(map_tensor)
return torch.cat(map_tensor_list)
def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
"""Repeat paths in an utterance.
For instance, if an utterance contains 3 paths: [path1 path2 path3],
after repeating, this utterance will contain 6 paths:
[path2 path3] [path1 path3] [path1 path2]
>>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ])
>>> tokens.to_str_simple()
'RaggedTensor([[[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]]], dtype=torch.int32)'
>>> make_repeat(tokens).to_str_simple()
'RaggedTensor([[[4, 5], [9], [1, 2, 3], [9], [1, 2, 3], [4, 5]], [[10, 1], [5, 8]]], dtype=torch.int32)' # noqa
TODO: Add documentation.
"""
assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}"
indexes = make_repeat_map(tokens.shape.row_splits(1))
return tokens.index(axis=1, indexes=indexes)[0]
def compute_alignment(
tokens: k2.RaggedTensor,
shape: k2.RaggedShape,
@ -285,15 +234,13 @@ def compute_alignment(
"""
assert tokens.tot_size(0) == shape.tot_size(1)
device = tokens.device
utt_path_shape = shape.compose(tokens.shape)
utt_path_token = k2.RaggedTensor(utt_path_shape, tokens.values)
utt_path_token_repeated = make_repeat(utt_path_token)
path_token_repeated = utt_path_token_repeated.remove_axis(0)
hyps = k2.levenshtein_graph(tokens, device=device)
refs = k2.levenshtein_graph(tokens, device=device)
hyps = k2.levenshtein_graph(path_token_repeated, device=device)
hyp_to_ref_map = make_hyp_to_ref_map(utt_path_shape.row_splits(1))
hyp_to_ref_map = make_hyp_to_ref_map(shape.row_splits(1))
alignment = k2.levenshtein_alignment(
refs=refs, hyps=hyps, hyp_to_ref_map=hyp_to_ref_map
)

View File

@ -22,8 +22,6 @@ from icefall.lm.rescore import (
add_eos,
compute_alignment,
make_hyp_to_ref_map,
make_repeat,
make_repeat_map,
prepare_conformer_lm_inputs,
)
@ -67,42 +65,9 @@ def test_pad():
def test_make_hyp_to_ref_map():
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
row_splits = a.shape.row_splits(1)
repeat_map = make_hyp_to_ref_map(row_splits)
# fmt: off
expected = torch.tensor([0, 0, 1, 1, 2, 2, 3,
3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]).to(repeat_map) # noqa
# fmt: on
assert torch.all(torch.eq(repeat_map, expected))
def test_make_repeat_map():
a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
row_splits = a.shape.row_splits(1)
repeat_map = make_repeat_map(row_splits)
# fmt: off
expected = torch.tensor([1, 2, 0, 2, 0, 1,
4, 5, 6, 3, 5, 6, 3, 4, 6, # noqa
3, 4, 5]).to(repeat_map) # noqa
# fmt: on
assert torch.all(torch.eq(repeat_map, expected))
def test_make_repeat():
# fmt: off
a = k2.RaggedTensor([
[[1, 3, 5], [2, 6]],
[[1, 2, 3, 4], [2], [], [9, 10, 11]],
])
b = make_repeat(a)
expected = k2.RaggedTensor([
[[2, 6], [1, 3, 5]],
[ [2], [], [9, 10, 11], # noqa
[1, 2, 3, 4], [], [9, 10, 11], # noqa
[1, 2, 3, 4], [2], [9, 10, 11], # noqa
[1, 2, 3, 4], [2], [], ], # noqa
])
# fmt: on
assert str(b) == str(expected)
hyp_to_ref_map = make_hyp_to_ref_map(row_splits)
expected = torch.tensor([0, 0, 0, 3, 3, 3, 3]).to(hyp_to_ref_map)
assert torch.all(torch.eq(hyp_to_ref_map, expected))
def test_compute_alignment():
@ -140,9 +105,7 @@ def main():
test_add_bos()
test_add_eos()
test_pad()
test_make_repeat_map()
test_make_hyp_to_ref_map()
test_make_repeat()
test_compute_alignment()