Finish preparing the inputs for conformer lm from an nbest object.
commit 3441634f34 · parent 1b9e4f0fea
@@ -42,6 +42,28 @@ import torch
 from icefall.decode import Nbest
 
 
+def make_key_padding_mask(lengths: torch.Tensor):
+    """
+    Return a bool mask in which True marks padding positions.
+
+    >>> make_key_padding_mask(torch.tensor([3, 1, 4]))
+    tensor([[False, False, False,  True],
+            [False,  True,  True,  True],
+            [False, False, False, False]])
+    """
+    assert lengths.dim() == 1
+
+    bs = lengths.numel()
+    max_len = lengths.max().item()
+    device = lengths.device
+    seq_range = torch.arange(0, max_len, device=device)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
+
+    seq_length_expand = lengths.unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+    return mask
+
+
 def concat(
     ragged: k2.RaggedTensor, value: int, direction: str
 ) -> k2.RaggedTensor:
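For context, the new helper follows the standard key-padding-mask idiom of torch.nn.Transformer: True marks positions attention should ignore, i.e. column j of row i is True iff j >= lengths[i]. A self-contained sketch of the same computation, runnable outside icefall:

import torch

def make_key_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    # True at padding positions: column j of row i is True iff j >= lengths[i]
    assert lengths.dim() == 1
    seq_range = torch.arange(int(lengths.max()), device=lengths.device)
    return seq_range.unsqueeze(0) >= lengths.unsqueeze(1)

print(make_key_padding_mask(torch.tensor([3, 1, 4])))
# tensor([[False, False, False,  True],
#         [False,  True,  True,  True],
#         [False, False, False, False]])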
@@ -77,13 +99,12 @@ def concat(
 
     assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
     pad_values = torch.full(
-        size=(ragged.tot_size(0),),
+        size=(ragged.tot_size(0), 1),
         fill_value=value,
         device=device,
         dtype=dtype,
     )
-    pad_shape = k2.ragged.regular_ragged_shape(ragged.tot_size(0), 1).to(device)
-    pad = k2.RaggedTensor(pad_shape, pad_values)
+    pad = k2.RaggedTensor(pad_values)
 
     if direction == "left":
         ans = k2.ragged.cat([pad, ragged], axis=1)
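The fix above simplifies how the padding column is built: instead of constructing a regular RaggedShape explicitly and pairing it with a 1-D value tensor, a 2-D tensor of shape (num_rows, 1) is handed straight to k2.RaggedTensor, which infers the regular ragged shape itself. A minimal sketch of the resulting behavior, assuming only the k2 calls already exercised in this diff:

import torch
import k2

ragged = k2.RaggedTensor([[1, 2], [3]])  # int32 by default
pad = k2.RaggedTensor(torch.full((ragged.tot_size(0), 1), 10, dtype=torch.int32))
# Prepend one padding value per row, as concat(..., direction="left") does:
print(k2.ragged.cat([pad, ragged], axis=1))  # [[10, 1, 2], [10, 3]], up to print format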
@@ -233,10 +254,10 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
     return k2.RaggedTensor(repeated).to(tokens.device)
 
 
-def compute_alignments(
+def compute_alignment(
     tokens: k2.RaggedTensor,
     shape: k2.RaggedShape,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> k2.Fsa:
     """
     TODO: Add documentation.
 
@@ -263,9 +284,90 @@ def compute_alignments(
     return alignment
 
 
+def prepare_conformer_lm_inputs(
+    alignment: k2.Fsa,
+    bos_id: int,
+    eos_id: int,
+    blank_id: int,
+    unmasked_weight: float = 0.25,
+) -> Tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
+]:
+    """
+    Prepare the inputs of a conformer LM from an alignment FSA.
+
+    Args:
+      alignment:
+        It is computed by :func:`compute_alignment`
+    """
+    # alignment.arcs.shape has axes [fsa][state][arc]
+    # we remove axis 1, i.e., state, here
+    labels_shape = alignment.arcs.shape().remove_axis(1)
+
+    masked_src = k2.RaggedTensor(labels_shape, alignment.labels.contiguous())
+    masked_src = masked_src.remove_values_eq(-1)
+    bos_masked_src = add_bos(masked_src, bos_id=bos_id)
+    bos_masked_src_eos = add_eos(bos_masked_src, eos_id=eos_id)
+    bos_masked_src_eos_pad = bos_masked_src_eos.pad(
+        mode="constant", padding_value=blank_id
+    )
+
+    src = k2.RaggedTensor(labels_shape, alignment.hyp_labels)
+    src = src.remove_values_eq(-1)
+    bos_src = add_bos(src, bos_id=bos_id)
+    bos_src_eos = add_eos(bos_src, eos_id=eos_id)
+    bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id)
+
+    tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels)
+    # TODO: Do we need to remove 0s from tgt?
+    tgt = tgt.remove_values_eq(-1)
+    tgt_eos = add_eos(tgt, eos_id=eos_id)
+
+    # append a blank so that tgt_eos has the same length as
+    # bos_src_eos, which starts with bos
+    tgt_eos = add_eos(tgt_eos, eos_id=blank_id)
+
+    row_splits = tgt_eos.shape.row_splits(1)
+    lengths = row_splits[1:] - row_splits[:-1]
+    src_key_padding_mask = make_key_padding_mask(lengths)
+
+    tgt_eos_pad = tgt_eos.pad(mode="constant", padding_value=blank_id)
+
+    weight = torch.full(
+        (tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
+        fill_value=1,
+        dtype=torch.float32,
+    )
+
+    # find unmasked positions
+    unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1]
+    weight[unmasked_positions] = unmasked_weight
+
+    # set weights for paddings
+    weight[src_key_padding_mask[:, 1:]] = 0
+    zeros = torch.zeros(weight.size(0), 1).to(weight)
+
+    weight = torch.cat((weight, zeros), dim=1)
+
+    # all other positions are assumed to be masked and
+    # have the default weight 1
+
+    return (
+        bos_masked_src_eos_pad,
+        bos_src_eos_pad,
+        tgt_eos_pad,
+        src_key_padding_mask,
+        weight,
+    )
+
+
 def conformer_lm_rescore(
     nbest: Nbest,
     model: torch.nn.Module,
+    bos_id: int,
+    eos_id: int,
+    blank_id: int,
+    unmasked_weight: float = 0.25,
     # TODO: add other arguments if needed
 ) -> k2.RaggedTensor:
     """Rescore an Nbest object with a conformer_lm model.
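To make the weighting scheme above concrete: after stripping the leading bos from the padded source and the trailing blank from the padded target, positions where the source token already equals the target token are treated as unmasked and down-weighted to unmasked_weight, padding positions are zeroed, and every other position keeps the default weight 1. A plain-tensor illustration with toy values (independent of k2; bos=10, eos=20, blank=0 as in the tests below):

import torch

unmasked_weight = 0.25
bos_src_eos_pad = torch.tensor([[10, 1, 2, 3, 20]])  # bos + hyp tokens + eos
tgt_eos_pad = torch.tensor([[1, 9, 3, 20, 0]])       # ref tokens + eos + blank

weight = torch.full((1, tgt_eos_pad.size(1) - 1), 1.0)
unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1]
weight[unmasked_positions] = unmasked_weight
print(weight)  # tensor([[0.2500, 1.0000, 0.2500, 0.2500]])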
@@ -281,4 +383,28 @@ def conformer_lm_rescore(
     contained in the nbest. Its shape equals `nbest.shape`.
     """
-    # TODO:
+    assert hasattr(nbest.fsa, "tokens")
+    utt_path_shape = nbest.shape
+    # nbest.fsa.arcs.shape() has axes [path][state][arc]
+    # We remove the state axis here
+    path_token_shape = nbest.fsa.arcs.shape().remove_axis(1)
+
+    path_token = k2.RaggedTensor(path_token_shape, nbest.fsa.tokens)
+    path_token = path_token.remove_values_leq(0)
+
+    alignment = compute_alignment(path_token, utt_path_shape)
+    (
+        masked_src,
+        src,
+        tgt,
+        src_key_padding_mask,
+        weight,
+    ) = prepare_conformer_lm_inputs(
+        alignment,
+        bos_id=bos_id,
+        eos_id=eos_id,
+        blank_id=blank_id,
+        unmasked_weight=unmasked_weight,
+    )
+    return masked_src, src, tgt, src_key_padding_mask, weight
+    # TODO: pass masked_src, src, tgt, src_key_padding_mask, and weight
+    # to the given model
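Per the trailing TODO, conformer_lm_rescore currently returns the prepared tensors rather than LM scores. A hedged sketch of what the missing final step might look like; the model(...) call and its signature are hypothetical placeholders, not an API defined by this commit:

# Hypothetical continuation of conformer_lm_rescore (interface assumed):
# nll = model(masked_src, src, tgt, src_key_padding_mask)  # (num_paths, T)
# tot_scores = -(nll * weight).sum(dim=1)  # one score per path
# return k2.RaggedTensor(nbest.shape, tot_scores)  # shape equals nbest.shape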
The remaining hunks are in the accompanying test file for icefall.lm.rescore (its path is not shown in this view):

@@ -17,13 +17,16 @@
 import k2
 import torch
 
+from icefall.decode import Nbest
 from icefall.lm.rescore import (
     add_bos,
     add_eos,
-    compute_alignments,
+    compute_alignment,
+    conformer_lm_rescore,
     make_hyp_to_ref_map,
     make_repeat,
     make_repeat_map,
+    prepare_conformer_lm_inputs,
 )
 
 
@@ -103,20 +106,51 @@ def test_make_repeat():
     assert str(b) == str(expected)
 
 
-def test_compute_alignments():
+def test_compute_alignment():
     # fmt: off
     tokens = k2.RaggedTensor([
         # utt 0
-        [1, 3], [8], [2],
+        [1, 3, 5, 8], [1, 5, 8], [2, 8, 3, 2],
         # utt 1
-        [1, 5], [9],
+        [2, 3], [2],
     ])
-    shape = k2.RaggedShape('[[x x x] [x x]]')
     # fmt: on
-    alignment = compute_alignments(tokens, shape)
-    print("maksed_src:", alignment.labels)
-    print("src:", alignment.hyp_labels)
-    print("tgt:", alignment.ref_labels)
+    shape = k2.RaggedShape("[[x x x] [x x]]")
+    alignment = compute_alignment(tokens, shape)
+    (
+        masked_src,
+        src,
+        tgt,
+        src_key_padding_mask,
+        weight,
+    ) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0)
+
+    # print("masked src", masked_src)
+    # print("src", src)
+    # print("tgt", tgt)
+    # print("src_key_padding_mask", src_key_padding_mask)
+    # print("weight", weight)
+
+
+def test_conformer_lm_rescore():
+    path00 = k2.linear_fsa([1, 2, 0, 3, 0, 5])
+    path01 = k2.linear_fsa([1, 0, 5, 0])
+    path10 = k2.linear_fsa([9, 8, 0, 3, 0, 2])
+    path11 = k2.linear_fsa([9, 8, 0, 0, 3, 2])
+    path12 = k2.linear_fsa([9, 0, 8, 4, 0, 2, 3])
+
+    fsa = k2.Fsa.from_fsas([path00, path01, path10, path11, path12])
+    fsa.tokens = fsa.labels.clone()
+    shape = k2.RaggedShape("[[x x] [x x x]]")
+    nbest = Nbest(fsa, shape)
+    masked_src, src, tgt, src_key_padding_mask, weight = conformer_lm_rescore(
+        nbest, model=None, bos_id=10, eos_id=20, blank_id=0
+    )
+    print("masked src", masked_src)
+    print("src", src)
+    print("tgt", tgt)
+    print("src_key_padding_mask", src_key_padding_mask)
+    print("weight", weight)
 
 
 def main():
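A note on the RaggedShape strings used by these tests: "[[x x] [x x x]]" groups the five linear FSAs into an Nbest with 2 paths for the first utterance and 3 for the second. A quick check using only calls that already appear in this diff:

import k2

shape = k2.RaggedShape("[[x x] [x x x]]")
print(shape.row_splits(1))  # tensor([0, 2, 5], dtype=torch.int32)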
@@ -126,7 +160,8 @@ def main():
     test_make_repeat_map()
     test_make_hyp_to_ref_map()
     test_make_repeat()
-    test_compute_alignments()
+    test_compute_alignment()
+    test_conformer_lm_rescore()
 
 
 if __name__ == "__main__":