Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-28 11:14:19 +00:00)
Use correct path pairs to compute log-likelihood.
commit d680b56c5c
parent cdd539e55c
@@ -498,6 +498,8 @@ def decode_dataset(
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
+        # if batch_idx > 10:
+        #     break
         texts = batch["supervisions"]["text"]

         hyps_dict = decode_one_batch(
@@ -539,7 +541,7 @@ def save_results(
     test_set_name: str,
     results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
 ):
-    if params.method == "attention-decoder":
+    if params.method in ("attention-decoder", "conformer-lm"):
         # Set it to False since there are too many logs.
         enable_log = False
     else:
@@ -591,7 +593,7 @@ def main():
     params = get_params()
     params.update(vars(args))

-    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
+    setup_logger(f"{params.exp_dir}/log-{params.method}-2/log-decode")
     logging.info("Decoding started")
     logging.info(params)

@@ -714,9 +716,15 @@ def main():
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
+    model.device = device

     if params.method == "conformer-lm":
         logging.info("Loading conformer lm model")
+        assert torch.cuda.device_count() > 1, f"{torch.cuda.device_count()}"
+
+        # We use a second GPU for masked LM model as it causes OOM
+        # with 1 GPU
+        device2 = torch.device("cuda", 1)
         # Note: If the parameters does not match
         # the one used to save the checkpoint, it will
         # throw while calling `load_state_dict`.
@@ -740,10 +748,12 @@ def main():
                 f"{params.conformer_lm_exp_dir}/epoch-{i}.pt"
             )
         logging.info(f"averaging {filenames}")
-        masked_lm_model.to(device)
+        masked_lm_model.to(device2)
         masked_lm_model.load_state_dict(
-            average_checkpoints(filenames, device=device)
+            average_checkpoints(filenames, device=device2)
         )
+        masked_lm_model.to(device2)
+        masked_lm_model.device = device2
     else:
         masked_lm_model = None

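The two hunks above pin the averaged masked LM to a second CUDA device. A minimal sketch of the same placement, assuming icefall's `average_checkpoints` helper; the stand-in model and checkpoint paths below are hypothetical, not from the commit:

    import torch
    import torch.nn as nn

    from icefall.checkpoint import average_checkpoints

    # Stand-in for the conformer LM; the real model comes from the recipe.
    masked_lm_model = nn.Linear(4, 4)

    assert torch.cuda.device_count() > 1, torch.cuda.device_count()
    device2 = torch.device("cuda", 1)  # second GPU, to avoid OOM on GPU 0

    # Hypothetical checkpoint paths; the recipe builds these from
    # params.conformer_lm_exp_dir.
    filenames = [f"conformer_lm/exp/epoch-{i}.pt" for i in (18, 19, 20)]
    masked_lm_model.to(device2)
    masked_lm_model.load_state_dict(
        average_checkpoints(filenames, device=device2)
    )
    masked_lm_model.device = device2  # remembered so inputs can be routed later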
@@ -756,6 +766,8 @@ def main():
     #
     test_sets = ["test-clean", "test-other"]
     for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
+        # if test_set == "test-clean":
+        #     continue
         results_dict = decode_dataset(
             dl=test_dl,
             params=params,
@@ -870,6 +870,7 @@ def rescore_with_attention_decoder(
         ngram_lm_scale_list = [0.01, 0.05, 0.08]
         ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         ngram_lm_scale_list = [ngram_lm_scale]

@@ -877,6 +878,7 @@ def rescore_with_attention_decoder(
         attention_scale_list = [0.01, 0.05, 0.08]
         attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         attention_scale_list = [attention_scale]

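These lists widen the grid search over rescoring weights. A hedged sketch of how such a grid is typically swept; the per-path scores and the exact loop body are illustrative, not the icefall implementation:

    import torch

    # Made-up per-path scores for illustration.
    am_scores = torch.tensor([0.5, 1.2, 0.3])
    ngram_lm_scores = torch.tensor([0.2, 0.1, 0.4])
    attention_scores = torch.tensor([1.0, 0.8, 0.9])

    ngram_lm_scale_list = [0.1, 0.5, 1.0, 2.0, 5.0]
    attention_scale_list = [0.1, 0.5, 1.0, 2.0, 5.0]

    ans = {}
    for n_scale in ngram_lm_scale_list:
        for a_scale in attention_scale_list:
            tot_scores = (
                am_scores
                + n_scale * ngram_lm_scores
                + a_scale * attention_scores
            )
            key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
            ans[key] = tot_scores.argmax()  # best path under this combination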
@@ -987,55 +989,79 @@ def rescore_with_conformer_lm(
     tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
     tokens = tokens.remove_values_leq(0)

-    alignment = compute_alignment(tokens, nbest.shape)
-    (
-        masked_src_symbols,
-        src_symbols,
-        tgt_symbols,
-        src_key_padding_mask,
-        tgt_weights,
-    ) = prepare_conformer_lm_inputs(
-        alignment,
-        bos_id=sos_id,
-        eos_id=eos_id,
-        blank_id=blank_id,
-        unmasked_weight=0.0,
-    )
-
-    masked_src_symbols = masked_src_symbols.to(torch.int64)
-    src_symbols = src_symbols.to(torch.int64)
-    tgt_symbols = tgt_symbols.to(torch.int64)
-
-    masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
-        masked_src_symbols, src_key_padding_mask
-    )
-
-    tgt_nll = masked_lm_model.decoder_nll(
-        masked_lm_memory,
-        masked_lm_pos_emb,
-        src_symbols,
-        tgt_symbols,
-        src_key_padding_mask,
-    )
-
-    # nll means negative log-likelihood
-    # ll means log-likelihood
-    tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)
-
-    # Note: log-likelihood for those pairs that have identical src/tgt are 0
-    # since their tgt_weights is 0
-
-    # TODO(fangjun): Add documentation about why we do the following
-    tgt_ll_shape_row_ids = make_hyp_to_ref_map(nbest.shape.row_splits(1))
-    tgt_ll_shape = k2.ragged.create_ragged_shape2(
-        row_splits=None,
-        row_ids=tgt_ll_shape_row_ids,
-        cached_tot_size=tgt_ll_shape_row_ids.numel(),
-    )
-    ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
-
-    ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
-    masked_lm_scores = ragged_tgt_ll.max()
+    device = model.device
+
+    # import pdb
+    #
+    # pdb.set_trace()
+    path_per_utt = (
+        nbest.shape.row_splits(1)[1:] - nbest.shape.row_splits(1)[:-1]
+    )
+    logging.info(f"path per utt: {path_per_utt}")
+    if 1 not in path_per_utt:
+        device2 = masked_lm_model.device
+
+        alignment = compute_alignment(
+            tokens.to(device2), nbest.shape.to(device2)
+        )
+        tgt_ll_list = []
+        for label_name in ["ref_labels", "hyp_labels"]:
+            (
+                masked_src_symbols,
+                src_symbols,
+                tgt_symbols,
+                src_key_padding_mask,
+                tgt_weights,
+            ) = prepare_conformer_lm_inputs(
+                alignment,
+                bos_id=sos_id,
+                eos_id=eos_id,
+                blank_id=blank_id,
+                src_label_name=label_name,
+                unmasked_weight=0.0,
+            )
+
+            masked_src_symbols = masked_src_symbols.to(torch.int64)
+            src_symbols = src_symbols.to(torch.int64)
+            tgt_symbols = tgt_symbols.to(torch.int64)
+
+            masked_lm_memory, masked_lm_pos_emb = masked_lm_model(
+                masked_src_symbols, src_key_padding_mask
+            )
+
+            tgt_nll = masked_lm_model.decoder_nll(
+                masked_lm_memory,
+                masked_lm_pos_emb,
+                src_symbols,
+                tgt_symbols,
+                src_key_padding_mask,
+            )
+
+            # nll means negative log-likelihood
+            # ll means log-likelihood
+            tgt_ll = -1 * (tgt_nll * tgt_weights).sum(dim=-1)

+            tgt_ll_list.append(tgt_ll)
+
+        # tgt_ll = tgt_ll_list[1] - tgt_ll_list[0]  # wer: 2.61
+        tgt_ll = tgt_ll_list[0] - tgt_ll_list[1]
+
+        # TODO(fangjun): Add documentation about why we do the following
+        tgt_ll_shape_row_ids = make_hyp_to_ref_map(
+            nbest.shape.row_splits(1).to(device2)
+        )
+        tgt_ll_shape = k2.ragged.create_ragged_shape2(
+            row_splits=None,
+            row_ids=tgt_ll_shape_row_ids,
+            cached_tot_size=tgt_ll_shape_row_ids.numel(),
+        )
+        ragged_tgt_ll = k2.RaggedTensor(tgt_ll_shape, tgt_ll)
+
+        ragged_tgt_ll = ragged_tgt_ll.remove_values_eq(0)
+        masked_lm_scores = ragged_tgt_ll.max().to(device)
+    else:
+        logging.warning(f"Disable masked lm. path per utt is: {path_per_utt}")
+        masked_lm_scores = torch.zeros_like(am_scores.values)

     # TODO(fangjun): Support passing a ragged tensor to `decoder_nll` directly.
     token_ids = tokens.tolist()
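The rewritten block derives `path_per_utt` from the ragged `row_splits` and skips masked-LM rescoring when any utterance has only one path, since the new ref/hyp pairing needs at least two paths per utterance. A runnable sketch with made-up values:

    import torch

    row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)  # 2 utterances
    path_per_utt = row_splits[1:] - row_splits[:-1]
    print(path_per_utt)  # tensor([3, 2], dtype=torch.int32): 3 paths, then 2

    # Rescoring runs only when every utterance has >= 2 paths:
    assert 1 not in path_per_utt

Each path is then scored twice by the masked LM, once with `ref_labels` and once with `hyp_labels` as the source, and the difference `tgt_ll_list[0] - tgt_ll_list[1]` becomes the rescoring score; the commented-out line records that the opposite sign gave a WER of 2.61.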
@@ -1056,6 +1082,7 @@ def rescore_with_conformer_lm(
         ngram_lm_scale_list = [0.01, 0.05, 0.08]
         ngram_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        ngram_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         ngram_lm_scale_list = [ngram_lm_scale]

@@ -1063,6 +1090,7 @@ def rescore_with_conformer_lm(
         attention_scale_list = [0.01, 0.05, 0.08]
         attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         attention_scale_list = [attention_scale]

@@ -1070,6 +1098,7 @@ def rescore_with_conformer_lm(
         masked_lm_scale_list = [0.01, 0.05, 0.08]
         masked_lm_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
         masked_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
+        masked_lm_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
     else:
         masked_lm_scale_list = [masked_lm_scale]

@@ -168,7 +168,7 @@ def make_hyp_to_ref_map(row_splits: torch.Tensor):

     >>> row_splits = torch.tensor([0, 3, 5], dtype=torch.int32)
     >>> make_hyp_to_ref_map(row_splits)
-    tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4], dtype=torch.int32)
+    tensor([0, 0, 1, 1, 2, 2, 3, 4], dtype=torch.int32)

     """
     device = row_splits.device
@@ -180,12 +180,12 @@ def make_hyp_to_ref_map(row_splits: torch.Tensor):
         # Explanation of the following operations
         # assume size is 3, offset is 2
         # torch.arange() + offset is [2, 3, 4]
-        # expand() is [[2, 3, 4], [2, 3, 4], [2, 3, 4]]
-        # t() is [[2, 2, 2], [3, 3, 3], [4, 4, 4]]
-        # reshape() is [2, 2, 2, 3, 3, 3, 4, 4, 4]
+        # expand() is [[2, 3, 4], [2, 3, 4]]
+        # t() is [[2, 2], [3, 3], [4, 4]]
+        # reshape() is [2, 2, 3, 3, 4, 4]
         map_tensor = (
             (torch.arange(size, dtype=torch.int32, device=device) + offset)
-            .expand(size, size)
+            .expand(size - 1, size)
             .t()
             .reshape(-1)
         )
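The updated comments trace the new `expand(size - 1, size)` trick. A runnable sketch reproducing them, with size=3 and offset=2 as in the comment:

    import torch

    size, offset = 3, 2
    map_tensor = (
        (torch.arange(size, dtype=torch.int32) + offset)
        .expand(size - 1, size)
        .t()
        .reshape(-1)
    )
    print(map_tensor)  # tensor([2, 2, 3, 3, 4, 4], dtype=torch.int32)

Each of the size paths is now repeated size - 1 times instead of size times, i.e. paired only with the other paths of its utterance, which is what shrinks the docstring example from 13 entries to [0, 0, 1, 1, 2, 2, 3, 4].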
@@ -219,6 +219,12 @@ def make_repeat_map(row_splits: torch.Tensor):
             .expand(size, size)
             .reshape(-1)
         )
+        diag_offset = torch.arange(size, device=device) * (size + 1)
+        # remove diagonal elements
+        map_tensor[diag_offset] = -1
+        map_tensor = map_tensor[map_tensor != -1]
+        # In the above example, map_tensor becomes
+        # [3, 4, 2, 4, 2, 3]
         map_tensor_list.append(map_tensor)

     return torch.cat(map_tensor_list)
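A sketch of the diagonal removal added to `make_repeat_map`, again with size=3 and offset=2 so the numbers match the new comment:

    import torch

    size, offset = 3, 2
    map_tensor = (
        (torch.arange(size, dtype=torch.int32) + offset)
        .expand(size, size)
        .reshape(-1)
        .clone()  # make it writable before in-place masking
    )
    # map_tensor is [2, 3, 4, 2, 3, 4, 2, 3, 4]; positions 0, 4, 8 are the
    # diagonal, i.e. each path paired with itself.
    diag_offset = torch.arange(size) * (size + 1)
    map_tensor[diag_offset] = -1
    map_tensor = map_tensor[map_tensor != -1]
    print(map_tensor)  # tensor([3, 4, 2, 4, 2, 3], dtype=torch.int32)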
@@ -233,25 +239,17 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
         [path1 path2 path3] [path1 path2 path3] [path1 path2 path3]

     >>> tokens = k2.RaggedTensor([ [[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]] ])
-    >>> tokens
-    [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] ] ]
-    >>> make_repeat(tokens)
-    [ [ [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] [ 1 2 3 ] [ 4 5 ] [ 9 ] ] [ [ 5 8 ] [ 10 1 ] [ 5 8 ] [ 10 1 ] ] ]  # noqa
+    >>> tokens.to_str_simple()
+    'RaggedTensor([[[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]]], dtype=torch.int32)'
+    >>> make_repeat(tokens).to_str_simple()
+    'RaggedTensor([[[4, 5], [9], [1, 2, 3], [9], [1, 2, 3], [4, 5]], [[10, 1], [5, 8]]], dtype=torch.int32)'  # noqa

     TODO: Add documentation.

     """
     assert tokens.num_axes == 3, f"num_axes: {tokens.num_axes}"
-    if True:
-        indexes = make_repeat_map(tokens.shape.row_splits(1))
-        return tokens.index(axis=1, indexes=indexes)[0]
-    else:
-        # This branch produces the same result as the above branch.
-        # It's more readable. Will remove it later.
-        repeated = []
-        for p in tokens.tolist():
-            repeated.append(p * len(p))
-        return k2.RaggedTensor(repeated).to(tokens.device)
+    indexes = make_repeat_map(tokens.shape.row_splits(1))
+    return tokens.index(axis=1, indexes=indexes)[0]


 def compute_alignment(
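A pure-Python rendering (a hypothetical helper, not part of the commit) of what `make_repeat` now produces: for each path, the other paths of the same utterance, matching the updated doctest above:

    def make_repeat_py(utts):
        out = []
        for paths in utts:
            row = []
            for i in range(len(paths)):
                # Every path except paths[i] itself, in order.
                row.extend(p for j, p in enumerate(paths) if j != i)
            out.append(row)
        return out

    print(make_repeat_py([[[1, 2, 3], [4, 5], [9]], [[5, 8], [10, 1]]]))
    # [[[4, 5], [9], [1, 2, 3], [9], [1, 2, 3], [4, 5]], [[10, 1], [5, 8]]]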
@@ -289,7 +287,8 @@ def prepare_conformer_lm_inputs(
     bos_id: int,
     eos_id: int,
     blank_id: int,
-    unmasked_weight: float = 0.25,
+    src_label_name: str,
+    unmasked_weight: float = 0.0,
 ) -> Tuple[
     torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
 ]:
@@ -299,7 +298,18 @@ def prepare_conformer_lm_inputs(
     Args:
       alignments:
         It is computed by :func:`compute_alignment`
+      bos_id:
+        ID of the bos symbol.
+      eos_id:
+        ID of the eos symbol.
+      blank_id:
+        ID of the blank symbol.
+      src_label_name:
+        The name of the attribute from `alignment` that will be used for `src`.
+        `tgt` is a shift version of `src`. Valid values are: "ref_labels"
+        and "hyp_labels".
     """
+    assert src_label_name in ("ref_labels", "hyp_labels")
     device = alignment.device
     # alignment.arcs.shape has axes [fsa][state][arc]
     # we remove axis 1, i.e., state, here
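A hypothetical call matching the updated signature (the IDs are made up; `alignment` is the output of `compute_alignment`, and the same call shape appears in the updated test below):

    (
        masked_src,
        src,
        tgt,
        src_key_padding_mask,
        weight,
    ) = prepare_conformer_lm_inputs(
        alignment,
        bos_id=10,
        eos_id=20,
        blank_id=0,
        src_label_name="hyp_labels",  # or "ref_labels"
        unmasked_weight=0.0,
    )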
@@ -313,13 +323,13 @@ def prepare_conformer_lm_inputs(
         mode="constant", padding_value=blank_id
     )

-    src = k2.RaggedTensor(labels_shape, alignment.hyp_labels)
+    src = k2.RaggedTensor(labels_shape, getattr(alignment, src_label_name))
     src = src.remove_values_eq(-1)
     bos_src = add_bos(src, bos_id=bos_id)
     bos_src_eos = add_eos(bos_src, eos_id=eos_id)
     bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id)

-    tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels)
+    tgt = k2.RaggedTensor(labels_shape, getattr(alignment, src_label_name))
     # TODO: Do we need to remove 0s from tgt ?
     tgt = tgt.remove_values_eq(-1)
     tgt_eos = add_eos(tgt, eos_id=eos_id)
@@ -342,7 +352,7 @@ def prepare_conformer_lm_inputs(
     )

     # find unmasked positions
-    unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1]
+    unmasked_positions = bos_masked_src_eos_pad[:, 1:] != 0
     weight[unmasked_positions] = unmasked_weight

     # set weights for paddings
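The new test for unmasked positions looks at the masked source itself rather than comparing src against tgt: a position counts as unmasked when its (bos-prefixed, padded) masked symbol is non-zero. A sketch with made-up token IDs:

    import torch

    # Rows: [bos, tok, tok, tok, eos/pad]; 0 marks masked-out or padded slots.
    bos_masked_src_eos_pad = torch.tensor([[10,  0,  0,  7, 20],
                                           [10,  3,  0, 20,  0]])
    unmasked_positions = bos_masked_src_eos_pad[:, 1:] != 0
    print(unmasked_positions)
    # tensor([[False, False,  True,  True],
    #         [ True, False,  True, False]])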
@@ -69,8 +69,8 @@ def test_make_hyp_to_ref_map():
     row_splits = a.shape.row_splits(1)
     repeat_map = make_hyp_to_ref_map(row_splits)
     # fmt: off
-    expected = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3,
-                             3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6]).to(repeat_map)  # noqa
+    expected = torch.tensor([0, 0, 1, 1, 2, 2, 3,
+                             3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6]).to(repeat_map)  # noqa
     # fmt: on
     assert torch.all(torch.eq(repeat_map, expected))

@@ -80,9 +80,9 @@ def test_make_repeat_map():
     row_splits = a.shape.row_splits(1)
     repeat_map = make_repeat_map(row_splits)
     # fmt: off
-    expected = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2,
-                             3, 4, 5, 6, 3, 4, 5, 6, 3, 4, 5, 6,  # noqa
-                             3, 4, 5, 6]).to(repeat_map)  # noqa
+    expected = torch.tensor([1, 2, 0, 2, 0, 1,
+                             4, 5, 6, 3, 5, 6, 3, 4, 6,  # noqa
+                             3, 4, 5]).to(repeat_map)  # noqa
     # fmt: on
     assert torch.all(torch.eq(repeat_map, expected))

@@ -95,11 +95,11 @@ def test_make_repeat():
     ])
     b = make_repeat(a)
     expected = k2.RaggedTensor([
-        [[1, 3, 5], [2, 6], [1, 3, 5], [2, 6]],
-        [[1, 2, 3, 4], [2], [], [9, 10, 11],
-         [1, 2, 3, 4], [2], [], [9, 10, 11],
-         [1, 2, 3, 4], [2], [], [9, 10, 11],
-         [1, 2, 3, 4], [2], [], [9, 10, 11]],
+        [[2, 6], [1, 3, 5]],
+        [ [2], [], [9, 10, 11],  # noqa
+         [1, 2, 3, 4], [], [9, 10, 11],  # noqa
+         [1, 2, 3, 4], [2], [9, 10, 11],  # noqa
+         [1, 2, 3, 4], [2], [], ],  # noqa
     ])
     # fmt: on
     assert str(b) == str(expected)
@@ -116,19 +116,24 @@ def test_compute_alignment():
     # fmt: on
     shape = k2.RaggedShape("[[x x x] [x x]]")
     alignment = compute_alignment(tokens, shape)
+    print(alignment.ref_labels)
+    print(alignment.hyp_labels)
+    print(alignment.labels)
     (
         masked_src,
         src,
         tgt,
         src_key_padding_mask,
         weight,
-    ) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0)
+    ) = prepare_conformer_lm_inputs(
+        alignment, bos_id=10, eos_id=20, blank_id=0, src_label_name="hyp_labels"
+    )

-    # print("masked src", masked_src)
-    # print("src", src)
-    # print("tgt", tgt)
-    # print("src_key_padding_mask", src_key_padding_mask)
-    # print("weight", weight)
+    print("masked src", masked_src)
+    print("src", src)
+    print("tgt", tgt)
+    print("src_key_padding_mask", src_key_padding_mask)
+    print("weight", weight)


 def main():