Use correct path pairs to compute log-likelihood.

This commit is contained in:
Fangjun Kuang 2021-11-15 10:01:16 +08:00
parent cdd539e55c
commit d680b56c5c
4 changed files with 140 additions and 84 deletions

View File

@ -498,6 +498,8 @@ def decode_dataset(
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
# if batch_idx > 10:
# break
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
@ -539,7 +541,7 @@ def save_results(
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]], results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
): ):
if params.method == "attention-decoder": if params.method in ("attention-decoder", "conformer-lm"):
# Set it to False since there are too many logs. # Set it to False since there are too many logs.
enable_log = False enable_log = False
else: else:
@ -591,7 +593,7 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") setup_logger(f"{params.exp_dir}/log-{params.method}-2/log-decode")
logging.info("Decoding started") logging.info("Decoding started")
logging.info(params) logging.info(params)
@ -714,9 +716,15 @@ def main():
model.eval() model.eval()
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
model.device = device
if params.method == "conformer-lm": if params.method == "conformer-lm":
logging.info("Loading conformer lm model") logging.info("Loading conformer lm model")
assert torch.cuda.device_count() > 1, f"{torch.cuda.device_count()}"
# We use a second GPU for masked LM model as it causes OOM
# with 1 GPU
device2 = torch.device("cuda", 1)
# Note: If the parameters does not match # Note: If the parameters does not match
# the one used to save the checkpoint, it will # the one used to save the checkpoint, it will
# throw while calling `load_state_dict`. # throw while calling `load_state_dict`.
@ -740,10 +748,12 @@ def main():
f"{params.conformer_lm_exp_dir}/epoch-{i}.pt" f"{params.conformer_lm_exp_dir}/epoch-{i}.pt"
) )
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
masked_lm_model.to(device) masked_lm_model.to(device2)
masked_lm_model.load_state_dict( masked_lm_model.load_state_dict(
average_checkpoints(filenames, device=device) average_checkpoints(filenames, device=device2)
) )
masked_lm_model.to(device2)
masked_lm_model.device = device2
else: else:
masked_lm_model = None masked_lm_model = None
@ -756,6 +766,8 @@ def main():
# #
test_sets = ["test-clean", "test-other"] test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
# if test_set == "test-clean":
# continue
results_dict = decode_dataset( results_dict = decode_dataset(
dl=test_dl, dl=test_dl,
params=params, params=params,

View File

@ -870,6 +870,7 @@ def rescore_with_attention_decoder(
ngram_lm_scale_list = [0.01, 0.05, 0.08] ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else: else:
ngram_lm_scale_list = [ngram_lm_scale] ngram_lm_scale_list = [ngram_lm_scale]
@ -877,6 +878,7 @@ def rescore_with_attention_decoder(
attention_scale_list = [0.01, 0.05, 0.08] attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else: else:
attention_scale_list = [attention_scale] attention_scale_list = [attention_scale]
@ -987,7 +989,23 @@ def rescore_with_conformer_lm(
tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
tokens = tokens.remove_values_leq(0) tokens = tokens.remove_values_leq(0)
alignment = compute_alignment(tokens, nbest.shape) device = model.device
# import pdb
#
# pdb.set_trace()
path_per_utt = (
nbest.shape.row_splits(1)[1:] - nbest.shape.row_splits(1)[:-1]
)
logging.info(f"path per utt: {path_per_utt}")
if 1 not in path_per_utt:
device2 = masked_lm_model.device
alignment = compute_alignment(
tokens.to(device2), nbest.shape.to(device2)
)
tgt_ll_list = []
for label_name in ["ref_labels", "hyp_labels"]:
( (
masked_src_symbols, masked_src_symbols,
src_symbols, src_symbols,
@ -999,6 +1017,7 @@ def rescore_with_conformer_lm(
bos_id=sos_id, bos_id=sos_id,
eos_id=eos_id, eos_id=eos_id,
blank_id=blank_id, blank_id=blank_id,
src_label_name=label_name,
unmasked_weight=0.0, unmasked_weight=0.0,
) )
@ -1022,11 +1041,15 @@ def rescore_with_conformer_lm(
# ll means log-likelihood # ll means log-likelihood
tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1) tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
# Note: log-likelihood for those pairs that have identical src/tgt are 0 tgt_ll_list.append(tgt_ll)
# since their tgt_weights is 0
# tgt_ll = tgt_ll_list[1] - tgt_ll_list[0] # wer: 2.61
tgt_ll = tgt_ll_list[0] - tgt_ll_list[1]
# TODO(fangjun): Add documentation about why we do the following # TODO(fangjun): Add documentation about why we do the following
tgt_ll_shape_row_ids = make_hyp_to_ref_map(nbest.shape.row_splits(1)) tgt_ll_shape_row_ids = make_hyp_to_ref_map(
nbest.shape.row_splits(1).to(device2)
)
tgt_ll_shape = k2.ragged.create_ragged_shape2( tgt_ll_shape = k2.ragged.create_ragged_shape2(
row_splits=None, row_splits=None,
row_ids=tgt_ll_shape_row_ids, row_ids=tgt_ll_shape_row_ids,
@ -1035,7 +1058,10 @@ def rescore_with_conformer_lm(
ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll) ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0) ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
masked_lm_scores = ragged_tgt_ll.max() masked_lm_scores = ragged_tgt_ll.max().to(device)
else:
logging.warning(f"Disable masked lm. path per utt is: {path_per_utt}")
masked_lm_scores = torch.zeros_like(am_scores.values)
# TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly. # TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
token_ids = tokens.tolist() token_ids = tokens.tolist()
@ -1056,6 +1082,7 @@ def rescore_with_conformer_lm(
ngram_lm_scale_list = [0.01, 0.05, 0.08] ngram_lm_scale_list = [0.01, 0.05, 0.08]
ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else: else:
ngram_lm_scale_list = [ngram_lm_scale] ngram_lm_scale_list = [ngram_lm_scale]
@ -1063,6 +1090,7 @@ def rescore_with_conformer_lm(
attention_scale_list = [0.01, 0.05, 0.08] attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else: else:
attention_scale_list = [attention_scale] attention_scale_list = [attention_scale]
@ -1070,6 +1098,7 @@ def rescore_with_conformer_lm(
masked_lm_scale_list = [0.01, 0.05, 0.08] masked_lm_scale_list = [0.01, 0.05, 0.08]
masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0] masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0] masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
masked_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else: else:
masked_lm_scale_list = [masked_lm_scale] masked_lm_scale_list = [masked_lm_scale]

View File

@ -168,7 +168,7 @@ def make_hyp_to_ref_map(row_splits: torch.Tensor):
>>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32) >>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
>>> make_hyp_to_ref_map(row_splits) >>> make_hyp_to_ref_map(row_splits)
tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4], dtype=torch.int32) tensor([0, 0, 1, 1, 2, 2, 3, 4], dtype=torch.int32)
""" """
device = row_splits.device device = row_splits.device
@ -180,12 +180,12 @@ def make_hyp_to_ref_map(row_splits: torch.Tensor):
# Explanation of the following operations # Explanation of the following operations
# assume size is 3, offset is 2 # assume size is 3, offset is 2
# torch.arange() + offset is [2, 3, 4] # torch.arange() + offset is [2, 3, 4]
# expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]] # expand() is [[2, 3, 4], [2, 3, 4]]
# t() is [[2, 2, 2], [3, 3, 3], [4, 4, 4]] # t() is [[2, 2], [3, 3], [4, 4]]
# reshape() is [2, 2, 2, 3, 3, 3, 4, 4, 4] # reshape() is [2, 2, 3, 3, 4, 4]
map_tensor = ( map_tensor = (
(torch.arange(size, dtype=torch.int32, device=device) + offset) (torch.arange(size, dtype=torch.int32, device=device) + offset)
.expand(size, size) .expand(size - 1, size)
.t() .t()
.reshape(-1) .reshape(-1)
) )
@ -219,6 +219,12 @@ def make_repeat_map(row_splits: torch.Tensor):
.expand(size, size) .expand(size, size)
.reshape(-1) .reshape(-1)
) )
diag_offset = torch.arange(size, device=device) * (size + 1)
# remove diagonal elements
map_tensor[diag_offset] = -1
map_tensor = map_tensor[map_tensor != -1]
# In the above example, map_tensor becomes
# [3, 4, 2, 4, 2, 3]
map_tensor_list.append(map_tensor) map_tensor_list.append(map_tensor)
return torch.cat(map_tensor_list) return torch.cat(map_tensor_list)
@ -233,25 +239,17 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
[path1 path2 path3] [path1 path2 path3] [path1 path2 path3] [path1 path2 path3] [path1 path2 path3] [path1 path2 path3]
>>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ]) >>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ])
>>> tokens >>> tokens.to_str_simple()
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ] 'RaggedTensor([[[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]]], dtype=torch.int32)'
>>> make_repeat(tokens) >>> make_repeat(tokens).to_str_simple()
[ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ] # noqa 'RaggedTensor([[[4, 5], [9], [1, 2, 3], [9], [1, 2, 3], [4, 5]], [[10, 1], [5, 8]]], dtype=torch.int32)' # noqa
TODO: Add documentation. TODO: Add documentation.
""" """
assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}" assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}"
if True:
indexes = make_repeat_map(tokens.shape.row_splits(1)) indexes = make_repeat_map(tokens.shape.row_splits(1))
return tokens.index(axis=1, indexes=indexes)[0] return tokens.index(axis=1, indexes=indexes)[0]
else:
# This branch produces the same result as the above branch.
# It's more readable. Will remove it later.
repeated = []
for p in tokens.tolist():
repeated.append(p * len(p))
return k2.RaggedTensor(repeated).to(tokens.device)
def compute_alignment( def compute_alignment(
@ -289,7 +287,8 @@ def prepare_conformer_lm_inputs(
bos_id: int, bos_id: int,
eos_id: int, eos_id: int,
blank_id: int, blank_id: int,
unmasked_weight: float = 0.25, src_label_name: str,
unmasked_weight: float = 0.0,
) -> Tuple[ ) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]: ]:
@ -299,7 +298,18 @@ def prepare_conformer_lm_inputs(
Args: Args:
alignments: alignments:
It is computed by :func:`compute_alignment` It is computed by :func:`compute_alignment`
bos_id:
ID of the bos symbol.
eos_id:
ID of the eos symbol.
blank_id:
ID of the blank symbol.
src_label_name:
The name of the attribute from `alignment` that will be used for `src`.
`tgt` is a shift version of `src`. Valid values are: "ref_labels"
and "hyp_labels".
""" """
assert src_label_name in ("ref_labels", "hyp_labels")
device = alignment.device device = alignment.device
# alignment.arcs.shape has axes [fsa][state][arc] # alignment.arcs.shape has axes [fsa][state][arc]
# we remove axis 1, i.e., state, here # we remove axis 1, i.e., state, here
@ -313,13 +323,13 @@ def prepare_conformer_lm_inputs(
mode="constant", padding_value=blank_id mode="constant", padding_value=blank_id
) )
src = k2.RaggedTensor(labels_shape, alignment.hyp_labels) src = k2.RaggedTensor(labels_shape, getattr(alignment, src_label_name))
src = src.remove_values_eq(-1) src = src.remove_values_eq(-1)
bos_src = add_bos(src, bos_id=bos_id) bos_src = add_bos(src, bos_id=bos_id)
bos_src_eos = add_eos(bos_src, eos_id=eos_id) bos_src_eos = add_eos(bos_src, eos_id=eos_id)
bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id) bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id)
tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels) tgt = k2.RaggedTensor(labels_shape, getattr(alignment, src_label_name))
# TODO: Do we need to remove 0s from tgt ? # TODO: Do we need to remove 0s from tgt ?
tgt = tgt.remove_values_eq(-1) tgt = tgt.remove_values_eq(-1)
tgt_eos = add_eos(tgt, eos_id=eos_id) tgt_eos = add_eos(tgt, eos_id=eos_id)
@ -342,7 +352,7 @@ def prepare_conformer_lm_inputs(
) )
# find unmasked positions # find unmasked positions
unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1] unmasked_positions = bos_masked_src_eos_pad[:, 1:] != 0
weight[unmasked_positions] = unmasked_weight weight[unmasked_positions] = unmasked_weight
# set weights for paddings # set weights for paddings

View File

@ -69,8 +69,8 @@ def test_make_hyp_to_ref_map():
row_splits = a.shape.row_splits(1) row_splits = a.shape.row_splits(1)
repeat_map = make_hyp_to_ref_map(row_splits) repeat_map = make_hyp_to_ref_map(row_splits)
# fmt: off # fmt: off
expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, expected = torch.tensor([0, 0, 1, 1, 2, 2, 3,
3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map) # noqa 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]).to(repeat_map) # noqa
# fmt: on # fmt: on
assert torch.all(torch.eq(repeat_map, expected)) assert torch.all(torch.eq(repeat_map, expected))
@ -80,9 +80,9 @@ def test_make_repeat_map():
row_splits = a.shape.row_splits(1) row_splits = a.shape.row_splits(1)
repeat_map = make_repeat_map(row_splits) repeat_map = make_repeat_map(row_splits)
# fmt: off # fmt: off
expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, expected = torch.tensor([1, 2, 0, 2, 0, 1,
3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6, # noqa 4, 5, 6, 3, 5, 6, 3, 4, 6, # noqa
3, 4, 5, 6]).to(repeat_map) # noqa 3, 4, 5]).to(repeat_map) # noqa
# fmt: on # fmt: on
assert torch.all(torch.eq(repeat_map, expected)) assert torch.all(torch.eq(repeat_map, expected))
@ -95,11 +95,11 @@ def test_make_repeat():
]) ])
b = make_repeat(a) b = make_repeat(a)
expected = k2.RaggedTensor([ expected = k2.RaggedTensor([
[[1, 3, 5], [2, 6], [1, 3, 5], [2, 6]], [[2, 6], [1, 3, 5]],
[[1, 2, 3, 4], [2], [], [9, 10, 11], [ [2], [], [9, 10, 11], # noqa
[1, 2, 3, 4], [2], [], [9, 10, 11], [1, 2, 3, 4], [], [9, 10, 11], # noqa
[1, 2, 3, 4], [2], [], [9, 10, 11], [1, 2, 3, 4], [2], [9, 10, 11], # noqa
[1, 2, 3, 4], [2], [], [9, 10, 11]], [1, 2, 3, 4], [2], [], ], # noqa
]) ])
# fmt: on # fmt: on
assert str(b) == str(expected) assert str(b) == str(expected)
@ -116,19 +116,24 @@ def test_compute_alignment():
# fmt: on # fmt: on
shape = k2.RaggedShape("[[x x x] [x x]]") shape = k2.RaggedShape("[[x x x] [x x]]")
alignment = compute_alignment(tokens, shape) alignment = compute_alignment(tokens, shape)
print(alignment.ref_labels)
print(alignment.hyp_labels)
print(alignment.labels)
( (
masked_src, masked_src,
src, src,
tgt, tgt,
src_key_padding_mask, src_key_padding_mask,
weight, weight,
) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0) ) = prepare_conformer_lm_inputs(
alignment, bos_id=10, eos_id=20, blank_id=0, src_label_name="hyp_labels"
)
# print("masked src", masked_src) print("masked src", masked_src)
# print("src", src) print("src", src)
# print("tgt", tgt) print("tgt", tgt)
# print("src_key_padding_mask", src_key_padding_mask) print("src_key_padding_mask", src_key_padding_mask)
# print("weight", weight) print("weight", weight)
def main(): def main():