mirror of https://github.com/k2-fsa/icefall.git

commit 878fb40a12 (parent 57b9c8868b)

    Fixes after review.
@@ -20,11 +20,7 @@ from typing import Dict, List, Optional, Union
 import k2
 import torch
 
-from icefall.lm.rescore import (
-    compute_alignment,
-    make_hyp_to_ref_map,
-    prepare_conformer_lm_inputs,
-)
+from icefall.lm.rescore import compute_alignment, prepare_conformer_lm_inputs
 from icefall.utils import get_texts
 
 
@@ -989,8 +985,6 @@ def rescore_with_conformer_lm(
     tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
     tokens = tokens.remove_values_leq(0)
 
     device = model.device
 
-    # import pdb
-    #
-    # pdb.set_trace()
@@ -998,70 +992,51 @@ def rescore_with_conformer_lm(
         nbest.shape.row_splits(1)[1:] - nbest.shape.row_splits(1)[:-1]
     )
-    logging.info(f"path per utt: {path_per_utt}")
-    if 1 not in path_per_utt:
-        device2 = masked_lm_model.device
-
-        alignment = compute_alignment(
-            tokens.to(device2), nbest.shape.to(device2)
-        )
-        tgt_ll_list = []
-        for label_name in ["ref_labels", "hyp_labels"]:
-            (
-                masked_src_symbols,
-                src_symbols,
-                tgt_symbols,
-                src_key_padding_mask,
-                tgt_weights,
-            ) = prepare_conformer_lm_inputs(
-                alignment,
-                bos_id=sos_id,
-                eos_id=eos_id,
-                blank_id=blank_id,
-                src_label_name=label_name,
-                unmasked_weight=0.0,
-            )
-
-            masked_src_symbols = masked_src_symbols.to(torch.int64)
-            src_symbols = src_symbols.to(torch.int64)
-            tgt_symbols = tgt_symbols.to(torch.int64)
-
-            masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
-                masked_src_symbols, src_key_padding_mask
-            )
-
-            tgt_nll = masked_lm_model.decoder_nll(
-                masked_lm_memory,
-                masked_lm_pos_emb,
-                src_symbols,
-                tgt_symbols,
-                src_key_padding_mask,
-            )
-
-            # nll means negative log-likelihood
-            # ll means log-likelihood
-            tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
-
-            tgt_ll_list.append(tgt_ll)
-
-        # tgt_ll = tgt_ll_list[1] - tgt_ll_list[0]  # wer: 2.61
-        tgt_ll = tgt_ll_list[0] - tgt_ll_list[1]
-
-        # TODO(fangjun): Add documentation about why we do the following
-        tgt_ll_shape_row_ids = make_hyp_to_ref_map(
-            nbest.shape.row_splits(1).to(device2)
-        )
-        tgt_ll_shape = k2.ragged.create_ragged_shape2(
-            row_splits=None,
-            row_ids=tgt_ll_shape_row_ids,
-            cached_tot_size=tgt_ll_shape_row_ids.numel(),
-        )
-        ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
-        ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
-        masked_lm_scores = ragged_tgt_ll.max().to(device)
-    else:
-        logging.warning(f"Disable masked lm. path per utt is: {path_per_utt}")
-        masked_lm_scores = torch.zeros_like(am_scores.values)
+    device2 = masked_lm_model.device
+
+    alignment = compute_alignment(tokens.to(device2), nbest.shape.to(device2))
+    tgt_ll_list = []
+    for label_name in ["ref_labels", "hyp_labels"]:
+        (
+            masked_src_symbols,
+            src_symbols,
+            tgt_symbols,
+            src_key_padding_mask,
+            tgt_weights,
+        ) = prepare_conformer_lm_inputs(
+            alignment,
+            bos_id=sos_id,
+            eos_id=eos_id,
+            blank_id=blank_id,
+            src_label_name=label_name,
+            unmasked_weight=0.0,
+        )
+
+        masked_src_symbols = masked_src_symbols.to(torch.int64)
+        src_symbols = src_symbols.to(torch.int64)
+        tgt_symbols = tgt_symbols.to(torch.int64)
+
+        masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
+            masked_src_symbols, src_key_padding_mask
+        )
+
+        tgt_nll = masked_lm_model.decoder_nll(
+            masked_lm_memory,
+            masked_lm_pos_emb,
+            src_symbols,
+            tgt_symbols,
+            src_key_padding_mask,
+        )
+
+        # nll means negative log-likelihood
+        # ll means log-likelihood
+        tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
+
+        tgt_ll_list.append(tgt_ll)
+
+    # hyp - ref
+    masked_lm_scores = tgt_ll_list[1] - tgt_ll_list[0]
 
     # TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
     token_ids = tokens.tolist()
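The net effect of this hunk: the old code scored pairs of distinct paths and took a per-utterance max inside an `if`; the new code always pairs every path with the first path of its utterance, so the scores reduce to a single "hyp - ref" subtraction. Below is a torch-only sketch of that reduction, with random stand-ins (`fake_nll` and the all-ones `tgt_weights` are invented here) in place of the conformer LM calls:

import torch

num_paths, seq_len = 5, 8  # 5 paths in one utterance, 8 target positions
tgt_ll_list = []
for label_name in ["ref_labels", "hyp_labels"]:
    fake_nll = torch.rand(num_paths, seq_len)     # stand-in for decoder_nll()
    tgt_weights = torch.ones(num_paths, seq_len)  # stand-in for prepare_conformer_lm_inputs()
    tgt_ll = -1 * (fake_nll * tgt_weights).sum(dim=-1)
    tgt_ll_list.append(tgt_ll)

# hyp - ref, as in the new code: one score per (ref, hyp) pair
masked_lm_scores = tgt_ll_list[1] - tgt_ll_list[0]
print(masked_lm_scores.shape)  # torch.Size([5])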
@@ -17,16 +17,22 @@
 """
 This file contains rescoring code for NN LMs, e.g., conformer LM.
 
-Support an utterance has 3 paths:
-    (a, b, c)
+Suppose that an utterance has 5 paths:
+    (a, b, c, d, e)
 
 and we want to use a masked conformer LM to assign a likelihood to each path.
 
 The following shows the steps:
 
-(1) Select path pairs:
-    (a, b), (a, c)
-    (b, a), (b, c)
-    (c, a), (c, b)
+(1) Select "a" as the reference path. Note: For ease of implementation,
+    we always select the first path as the reference one.
+
+(2) We have the following path pairs:
+
+    (a, a), (a, b), (a, c), (a, d), (a, e)
+
+    Note: Even if we know the likelihood for the pair (a, a) is always 0,
+    we still list it here for ease of implementation.
 
 (2) For each pair, e.g., for the pair (a, b),
@@ -39,20 +45,27 @@ The following shows the steps:
   (iv) Use "b" as "src" and its shifted version as "tgt".
        We can get another likelihood value, denoted as "ab_other"
 
-So for the path pair (a, b), (a, c), (b, a), (b, c), (c, a), and (c, b),
+So for the path pairs (a, a), (a, b), (a, c), (a, d), and (a, e),
 we can get the following log-likelihood values, viewed as two tensors:
 
-    self = [ab_self, ac_self, ba_self, bc_self, ca_self, cb_self]
+    self = [aa_self, ab_self, ac_self, ad_self, ae_self]
 
-    other = [ab_other, ac_other, ba_other, bc_other, ca_other, cb_other]
+    other = [aa_other, ab_other, ac_other, ad_other, ae_other]
 
-Compute the difference the two tensors:
+Compute the difference between the two tensors:
 
-    self - other = [ab_self - ab_other, ac_self - ac_other, ...]
+    other - self = [aa_other - aa_self, ab_other - ab_self,
+                    ac_other - ac_self, ... ]
 
-The log-likelihood for path a is : max(ab_self - ab_other, ac_self - ac_other)
-The log-likelihood for path b is : max(ba_self - ba_other, bc_self - bc_other)
-The log-likelihood for path c is : max(ca_self - ca_other, cb_self - cb_other)
+The log-likelihood for path a is: 0
+The log-likelihood for path b is: ab_other - ab_self
+The log-likelihood for path c is: ac_other - ac_self
+The log-likelihood for path d is: ad_other - ad_self
+The log-likelihood for path e is: ae_other - ae_self
+
+Note: "ab_other - ab_self" can be interpreted as
+
+    log P(b) - log P(a)
 """
 
 from typing import Tuple
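To make the arithmetic in the rewritten docstring concrete, here is a tiny numeric example; the likelihood values are made up, only the "other - self" rule comes from the docstring above:

import torch

# hypothetical log-likelihoods for the pairs (a, a) ... (a, e)
self_ll = torch.tensor([-1.0, -1.0, -1.0, -1.0, -1.0])
other_ll = torch.tensor([-1.0, -0.5, -2.0, -1.5, -0.8])

scores = other_ll - self_ll  # log P(path) - log P(a) for each path
print(scores)  # tensor([ 0.0000,  0.5000, -1.0000, -0.5000,  0.2000])
# path a scores 0 by construction; b and e are judged more likely than a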
@@ -181,14 +194,13 @@ def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
     return concat(ragged, eos_id, direction="right")
 
 
-def make_hyp_to_ref_map(row_splits: torch.Tensor):
+def make_hyp_to_ref_map(row_splits: torch.Tensor) -> torch.Tensor:
     """
-    TODO: Add documentation.
+    TODO: Add documentation
 
     >>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
     >>> make_hyp_to_ref_map(row_splits)
-    tensor([0, 0, 1, 1, 2, 2, 3, 4], dtype=torch.int32)
-
+    tensor([0, 0, 0, 3, 3], dtype=torch.int32)
     """
     device = row_splits.device
    sizes = (row_splits[1:] - row_splits[:-1]).tolist()
@@ -198,78 +210,15 @@ def make_hyp_to_ref_map(row_splits: torch.Tensor):
     for size, offset in zip(sizes, offsets):
         # Explanation of the following operations
         # assume size is 3, offset is 2
-        # torch.arange() + offset is [2, 3, 4]
-        # expand() is [[2, 3, 4], [2, 3, 4]]
-        # t() is [[2, 2], [3, 3], [4, 4]]
-        # reshape() is [2, 2, 3, 3, 4, 4]
+        # torch.zeros() + offset is [2, 2, 2]
         map_tensor = (
-            (torch.arange(size, dtype=torch.int32, device=device) + offset)
-            .expand(size - 1, size)
-            .t()
-            .reshape(-1)
+            torch.zeros(size, dtype=torch.int32, device=device) + offset
         )
         map_tensor_list.append(map_tensor)
 
     return torch.cat(map_tensor_list)
 
 
-def make_repeat_map(row_splits: torch.Tensor):
-    """
-    TODO: Add documentation.
-
-    >>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
-    >>> make_repeat_map(row_splits)
-    tensor([1, 2, 0, 2, 0, 1, 4, 3], dtype=torch.int32)
-
-    """
-    device = row_splits.device
-    sizes = (row_splits[1:] - row_splits[:-1]).tolist()
-    offsets = row_splits[:-1]
-
-    map_tensor_list = []
-    for size, offset in zip(sizes, offsets):
-        # Explanation of the following operations
-        # assume size is 3, offset is 2
-        # torch.arange() + offset is [2, 3, 4]
-        # expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
-        # reshape() is [2, 3, 4, 2, 3, 4, 2, 3, 4]
-        map_tensor = (
-            (torch.arange(size, dtype=torch.int32, device=device) + offset)
-            .expand(size, size)
-            .reshape(-1)
-        )
-        diag_offset = torch.arange(size, device=device) * (size + 1)
-        # remove diagonal elements
-        map_tensor[diag_offset] = -1
-        map_tensor = map_tensor[map_tensor != -1]
-        # In the above example, map_tensor becomes
-        # [3, 4, 2, 4, 2, 3]
-        map_tensor_list.append(map_tensor)
-
-    return torch.cat(map_tensor_list)
-
-
-def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
-    """Repeat paths in an utterance.
-
-    For instance, if an utterance contains 3 paths: [path1 path2 path3],
-    after repeating, this utterance will contain 6 paths:
-    [path2 path3] [path1 path3] [path1 path2]
-
-    >>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ])
-    >>> tokens.to_str_simple()
-    'RaggedTensor([[[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]]], dtype=torch.int32)'
-    >>> make_repeat(tokens).to_str_simple()
-    'RaggedTensor([[[4, 5], [9], [1, 2, 3], [9], [1, 2, 3], [4, 5]], [[10, 1], [5, 8]]], dtype=torch.int32)'  # noqa
-
-    TODO: Add documentation.
-
-    """
-    assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}"
-    indexes = make_repeat_map(tokens.shape.row_splits(1))
-    return tokens.index(axis=1, indexes=indexes)[0]
-
-
 def compute_alignment(
     tokens: k2.RaggedTensor,
     shape: k2.RaggedShape,
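Pulling the new version of make_hyp_to_ref_map() out of the hunk above: it now only repeats each utterance's first-path index once per path. A self-contained sketch assembled from the diff, checked against the doctest values:

import torch

def make_hyp_to_ref_map(row_splits: torch.Tensor) -> torch.Tensor:
    device = row_splits.device
    sizes = (row_splits[1:] - row_splits[:-1]).tolist()
    offsets = row_splits[:-1]

    map_tensor_list = []
    for size, offset in zip(sizes, offsets):
        # e.g., size 3 and offset 2 give [2, 2, 2]
        map_tensor = torch.zeros(size, dtype=torch.int32, device=device) + offset
        map_tensor_list.append(map_tensor)

    return torch.cat(map_tensor_list)

print(make_hyp_to_ref_map(torch.tensor([0, 3, 5], dtype=torch.int32)))
# tensor([0, 0, 0, 3, 3], dtype=torch.int32)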
@@ -285,15 +234,13 @@ def compute_alignment(
     """
     assert tokens.tot_size(0) == shape.tot_size(1)
     device = tokens.device
-    utt_path_shape = shape.compose(tokens.shape)
-    utt_path_token = k2.RaggedTensor(utt_path_shape, tokens.values)
-    utt_path_token_repeated = make_repeat(utt_path_token)
-    path_token_repeated = utt_path_token_repeated.remove_axis(0)
 
     refs = k2.levenshtein_graph(tokens, device=device)
-    hyps = k2.levenshtein_graph(path_token_repeated, device=device)
+    hyps = k2.levenshtein_graph(tokens, device=device)
 
-    hyp_to_ref_map = make_hyp_to_ref_map(utt_path_shape.row_splits(1))
+    hyp_to_ref_map = make_hyp_to_ref_map(shape.row_splits(1))
 
     alignment = k2.levenshtein_alignment(
         refs=refs, hyps=hyps, hyp_to_ref_map=hyp_to_ref_map
     )
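With refs and hyps now built from the same `tokens`, `hyp_to_ref_map` is what makes each path align against its utterance's reference path rather than itself. An interpretation sketch (the indices follow the doctest above; the pairing semantics of k2.levenshtein_alignment are assumed here, not demonstrated):

import torch

hyp_to_ref_map = torch.tensor([0, 0, 0, 3, 3], dtype=torch.int32)
for hyp_idx, ref_idx in enumerate(hyp_to_ref_map.tolist()):
    print(f"hyps[{hyp_idx}] is aligned against refs[{ref_idx}]")
# hyps[0]/refs[0] and hyps[3]/refs[3] are the trivial (a, a)-style pairs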
@@ -22,8 +22,6 @@ from icefall.lm.rescore import (
     add_eos,
     compute_alignment,
     make_hyp_to_ref_map,
-    make_repeat,
-    make_repeat_map,
     prepare_conformer_lm_inputs,
 )
 
@@ -67,42 +65,9 @@ def test_pad():
 def test_make_hyp_to_ref_map():
     a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
     row_splits = a.shape.row_splits(1)
-    repeat_map = make_hyp_to_ref_map(row_splits)
-    # fmt: off
-    expected = torch.tensor([0, 0, 1, 1, 2, 2, 3,
-                             3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]).to(repeat_map)  # noqa
-    # fmt: on
-    assert torch.all(torch.eq(repeat_map, expected))
-
-
-def test_make_repeat_map():
-    a = k2.RaggedTensor([[[1, 2], [], [3]], [[1, 3], [2], [4], [5]]])
-    row_splits = a.shape.row_splits(1)
-    repeat_map = make_repeat_map(row_splits)
-    # fmt: off
-    expected = torch.tensor([1, 2, 0, 2, 0, 1,
-                             4, 5, 6, 3, 5, 6, 3, 4, 6,  # noqa
-                             3, 4, 5]).to(repeat_map)  # noqa
-    # fmt: on
-    assert torch.all(torch.eq(repeat_map, expected))
-
-
-def test_make_repeat():
-    # fmt: off
-    a = k2.RaggedTensor([
-        [[1, 3, 5], [2, 6]],
-        [[1, 2, 3, 4], [2], [], [9, 10, 11]],
-    ])
-    b = make_repeat(a)
-    expected = k2.RaggedTensor([
-        [[2, 6], [1, 3, 5]],
-        [ [2], [], [9, 10, 11],  # noqa
-          [1, 2, 3, 4], [], [9, 10, 11],  # noqa
-          [1, 2, 3, 4], [2], [9, 10, 11],  # noqa
-          [1, 2, 3, 4], [2], [], ],  # noqa
-    ])
-    # fmt: on
-    assert str(b) == str(expected)
+    hyp_to_ref_map = make_hyp_to_ref_map(row_splits)
+    expected = torch.tensor([0, 0, 0, 3, 3, 3, 3]).to(hyp_to_ref_map)
+    assert torch.all(torch.eq(hyp_to_ref_map, expected))
 
 
 def test_compute_alignment():
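The new expectation can be derived by hand: the test tensor has utterances with 3 and 4 paths, so row_splits(1) is [0, 3, 7], and repeating each utterance's first-path index once per path gives the expected map. A quick check, independent of k2, using only the row_splits:

import torch

row_splits = torch.tensor([0, 3, 7], dtype=torch.int32)
sizes = (row_splits[1:] - row_splits[:-1]).tolist()  # [3, 4]
offsets = row_splits[:-1].tolist()                   # [0, 3]
expected = torch.cat(
    [torch.full((n,), off, dtype=torch.int32) for n, off in zip(sizes, offsets)]
)
print(expected)  # tensor([0, 0, 0, 3, 3, 3, 3], dtype=torch.int32)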
@@ -140,9 +105,7 @@ def main():
     test_add_bos()
     test_add_eos()
     test_pad()
-    test_make_repeat_map()
     test_make_hyp_to_ref_map()
-    test_make_repeat()
     test_compute_alignment()
 
 