mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Finish preparing the inputs for conformer lm from an nbest object.
This commit is contained in:
parent
1b9e4f0fea
commit
3441634f34
@ -42,6 +42,28 @@ import torch
|
|||||||
from icefall.decode import Nbest
|
from icefall.decode import Nbest
|
||||||
|
|
||||||
|
|
||||||
|
def make_key_padding_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Entry ``[i, j]`` of the returned mask is ``True`` when
    ``j >= lengths[i]``, i.e., when position ``j`` lies past the end of
    the i-th sequence, and ``False`` otherwise.

    Args:
      lengths:
        A 1-D tensor. ``lengths[i]`` is the number of valid (non-padding)
        positions of the i-th sequence.

    Returns:
      A bool tensor of shape ``(lengths.numel(), lengths.max())`` on the
      same device as ``lengths``. If ``lengths`` is empty, an empty mask
      of shape ``(0, 0)`` is returned.

    >>> make_key_padding_mask(torch.tensor([3, 1, 4]))
    tensor([[False, False, False,  True],
            [False,  True,  True,  True],
            [False, False, False, False]])
    """
    assert lengths.dim() == 1

    device = lengths.device
    if lengths.numel() == 0:
        # torch.Tensor.max() raises on an empty tensor, so handle the
        # "no sequences" case explicitly.
        return torch.zeros((0, 0), dtype=torch.bool, device=device)

    max_len = int(lengths.max().item())
    positions = torch.arange(max_len, device=device)

    # Broadcast (1, max_len) against (bs, 1) to get a (bs, max_len) mask.
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
def concat(
|
def concat(
|
||||||
ragged: k2.RaggedTensor, value: int, direction: str
|
ragged: k2.RaggedTensor, value: int, direction: str
|
||||||
) -> k2.RaggedTensor:
|
) -> k2.RaggedTensor:
|
||||||
@ -77,13 +99,12 @@ def concat(
|
|||||||
|
|
||||||
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
|
assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
|
||||||
pad_values = torch.full(
|
pad_values = torch.full(
|
||||||
size=(ragged.tot_size(0),),
|
size=(ragged.tot_size(0), 1),
|
||||||
fill_value=value,
|
fill_value=value,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
pad_shape = k2.ragged.regular_ragged_shape(ragged.tot_size(0), 1).to(device)
|
pad = k2.RaggedTensor(pad_values)
|
||||||
pad = k2.RaggedTensor(pad_shape, pad_values)
|
|
||||||
|
|
||||||
if direction == "left":
|
if direction == "left":
|
||||||
ans = k2.ragged.cat([pad, ragged], axis=1)
|
ans = k2.ragged.cat([pad, ragged], axis=1)
|
||||||
@ -233,10 +254,10 @@ def make_repeat(tokens: k2.RaggedTensor) -> k2.RaggedTensor:
|
|||||||
return k2.RaggedTensor(repeated).to(tokens.device)
|
return k2.RaggedTensor(repeated).to(tokens.device)
|
||||||
|
|
||||||
|
|
||||||
def compute_alignments(
|
def compute_alignment(
|
||||||
tokens: k2.RaggedTensor,
|
tokens: k2.RaggedTensor,
|
||||||
shape: k2.RaggedShape,
|
shape: k2.RaggedShape,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> k2.Fsa:
|
||||||
"""
|
"""
|
||||||
TODO: Add documentation.
|
TODO: Add documentation.
|
||||||
|
|
||||||
@ -263,9 +284,90 @@ def compute_alignments(
|
|||||||
return alignment
|
return alignment
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_conformer_lm_inputs(
    alignment: k2.Fsa,
    bos_id: int,
    eos_id: int,
    blank_id: int,
    unmasked_weight: float = 0.25,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
    """Convert an alignment into padded tensors for a conformer LM.

    Args:
      alignment:
        It is computed by :func:`compute_alignment`. This function reads
        its attributes ``labels``, ``hyp_labels`` and ``ref_labels``.
      bos_id:
        ID prepended to the masked source and to the source sequences.
      eos_id:
        ID appended to the masked source, source and target sequences.
      blank_id:
        Value used to pad all three outputs. It is also appended once
        after EOS in every target sequence.
      unmasked_weight:
        Weight assigned to positions where the source (shifted left by
        one to skip BOS) equals the target.

    Returns:
      Return a tuple of 5 tensors:

        - masked_src: BOS + ``alignment.labels`` + EOS, padded
        - src: BOS + ``alignment.hyp_labels`` + EOS, padded
        - tgt: ``alignment.ref_labels`` + EOS + blank, padded
        - src_key_padding_mask: bool mask built from the target lengths
          by :func:`make_key_padding_mask`; True marks padding
        - weight: float32 tensor with the same shape as ``tgt``; its
          last column is always 0
    """
    # alignment.arcs.shape has axes [fsa][state][arc]
    # we remove axis 1, i.e., state, here
    labels_shape = alignment.arcs.shape().remove_axis(1)

    masked_src = k2.RaggedTensor(labels_shape, alignment.labels.contiguous())
    masked_src = masked_src.remove_values_eq(-1)
    bos_masked_src = add_bos(masked_src, bos_id=bos_id)
    bos_masked_src_eos = add_eos(bos_masked_src, eos_id=eos_id)
    bos_masked_src_eos_pad = bos_masked_src_eos.pad(
        mode="constant", padding_value=blank_id
    )

    src = k2.RaggedTensor(labels_shape, alignment.hyp_labels)
    src = src.remove_values_eq(-1)
    bos_src = add_bos(src, bos_id=bos_id)
    bos_src_eos = add_eos(bos_src, eos_id=eos_id)
    bos_src_eos_pad = bos_src_eos.pad(mode="constant", padding_value=blank_id)

    tgt = k2.RaggedTensor(labels_shape, alignment.ref_labels)
    # TODO: Do we need to remove 0s from tgt ?
    tgt = tgt.remove_values_eq(-1)
    tgt_eos = add_eos(tgt, eos_id=eos_id)

    # tgt_eos does not start with bos, so append one blank to it;
    # add_eos() is reused here only to append `blank_id`.
    tgt_eos = add_eos(tgt_eos, eos_id=blank_id)

    row_splits = tgt_eos.shape.row_splits(1)
    lengths = row_splits[1:] - row_splits[:-1]
    src_key_padding_mask = make_key_padding_mask(lengths)

    tgt_eos_pad = tgt_eos.pad(mode="constant", padding_value=blank_id)

    # Every position starts as "masked" with the default weight 1.
    # The tensor is created on the device of the padded target: the boolean
    # masks used to index it below live on that device, and indexing a CPU
    # tensor with a CUDA mask raises an error.
    weight = torch.full(
        (tgt_eos_pad.size(0), tgt_eos_pad.size(1) - 1),
        fill_value=1,
        dtype=torch.float32,
        device=tgt_eos_pad.device,
    )

    # find unmasked positions: src without its leading bos agrees with tgt
    unmasked_positions = bos_src_eos_pad[:, 1:] == tgt_eos_pad[:, :-1]
    weight[unmasked_positions] = unmasked_weight

    # set weights for paddings
    weight[src_key_padding_mask[:, 1:]] = 0

    # Append a zero column so that weight has as many columns as tgt_eos_pad.
    zeros = weight.new_zeros((weight.size(0), 1))
    weight = torch.cat((weight, zeros), dim=1)

    return (
        bos_masked_src_eos_pad,
        bos_src_eos_pad,
        tgt_eos_pad,
        src_key_padding_mask,
        weight,
    )
|
||||||
|
|
||||||
|
|
||||||
def conformer_lm_rescore(
|
def conformer_lm_rescore(
|
||||||
nbest: Nbest,
|
nbest: Nbest,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
|
bos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
blank_id: int,
|
||||||
|
unmasked_weight: float = 0.25,
|
||||||
# TODO: add other arguments if needed
|
# TODO: add other arguments if needed
|
||||||
) -> k2.RaggedTensor:
|
) -> k2.RaggedTensor:
|
||||||
"""Rescore an Nbest object with a conformer_lm model.
|
"""Rescore an Nbest object with a conformer_lm model.
|
||||||
@ -281,4 +383,28 @@ def conformer_lm_rescore(
|
|||||||
contained in the nbest. Its shape equals to `nbest.shape`.
|
contained in the nbest. Its shape equals to `nbest.shape`.
|
||||||
"""
|
"""
|
||||||
assert hasattr(nbest.fsa, "tokens")
|
assert hasattr(nbest.fsa, "tokens")
|
||||||
# TODO:
|
utt_path_shape = nbest.shape
|
||||||
|
# nbest.fsa.arcs.shape() has axes [path][state][arc]
|
||||||
|
# We remove the state axis here
|
||||||
|
path_token_shape = nbest.fsa.arcs.shape().remove_axis(1)
|
||||||
|
|
||||||
|
path_token = k2.RaggedTensor(path_token_shape, nbest.fsa.tokens)
|
||||||
|
path_token = path_token.remove_values_leq(0)
|
||||||
|
|
||||||
|
alignment = compute_alignment(path_token, utt_path_shape)
|
||||||
|
(
|
||||||
|
masked_src,
|
||||||
|
src,
|
||||||
|
tgt,
|
||||||
|
src_key_padding_mask,
|
||||||
|
weight,
|
||||||
|
) = prepare_conformer_lm_inputs(
|
||||||
|
alignment,
|
||||||
|
bos_id=bos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
blank_id=blank_id,
|
||||||
|
unmasked_weight=unmasked_weight,
|
||||||
|
)
|
||||||
|
return masked_src, src, tgt, src_key_padding_mask, weight
|
||||||
|
# TODO: pass masked_src, src, tgt, src_key_padding_mask, and weight
|
||||||
|
# to the given model
|
||||||
|
@ -17,13 +17,16 @@
|
|||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from icefall.decode import Nbest
|
||||||
from icefall.lm.rescore import (
|
from icefall.lm.rescore import (
|
||||||
add_bos,
|
add_bos,
|
||||||
add_eos,
|
add_eos,
|
||||||
make_repeat_map,
|
compute_alignment,
|
||||||
|
conformer_lm_rescore,
|
||||||
make_hyp_to_ref_map,
|
make_hyp_to_ref_map,
|
||||||
make_repeat,
|
make_repeat,
|
||||||
compute_alignments,
|
make_repeat_map,
|
||||||
|
prepare_conformer_lm_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -103,20 +106,51 @@ def test_make_repeat():
|
|||||||
assert str(b) == str(expected)
|
assert str(b) == str(expected)
|
||||||
|
|
||||||
|
|
||||||
def test_compute_alignments():
|
def test_compute_alignment():
|
||||||
# fmt: off
|
# fmt: off
|
||||||
tokens = k2.RaggedTensor([
|
tokens = k2.RaggedTensor([
|
||||||
# utt 0
|
# utt 0
|
||||||
[1, 3], [8], [2],
|
[1, 3, 5, 8], [1, 5, 8], [2, 8, 3, 2],
|
||||||
# utt 1
|
# utt 1
|
||||||
[1, 5], [9],
|
[2, 3], [2],
|
||||||
])
|
])
|
||||||
shape = k2.RaggedShape('[[x x x] [x x]]')
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
alignment = compute_alignments(tokens, shape)
|
shape = k2.RaggedShape("[[x x x] [x x]]")
|
||||||
print("maksed_src:", alignment.labels)
|
alignment = compute_alignment(tokens, shape)
|
||||||
print("src:", alignment.hyp_labels)
|
(
|
||||||
print("tgt:", alignment.ref_labels)
|
masked_src,
|
||||||
|
src,
|
||||||
|
tgt,
|
||||||
|
src_key_padding_mask,
|
||||||
|
weight,
|
||||||
|
) = prepare_conformer_lm_inputs(alignment, bos_id=10, eos_id=20, blank_id=0)
|
||||||
|
|
||||||
|
# print("masked src", masked_src)
|
||||||
|
# print("src", src)
|
||||||
|
# print("tgt", tgt)
|
||||||
|
# print("src_key_padding_mask", src_key_padding_mask)
|
||||||
|
# print("weight", weight)
|
||||||
|
|
||||||
|
|
||||||
|
def test_conformer_lm_rescore():
    """Run conformer_lm_rescore on a small hand-built Nbest and print
    the five tensors it returns.

    The Nbest holds five linear FSAs grouped by the ragged shape
    "[[x x] [x x x]]", i.e., sublists of size 2 and 3.
    """
    label_seqs = [
        [1, 2, 0, 3, 0, 5],
        [1, 0, 5, 0],
        [9, 8, 0, 3, 0, 2],
        [9, 8, 0, 0, 3, 2],
        [9, 0, 8, 4, 0, 2, 3],
    ]
    paths = [k2.linear_fsa(seq) for seq in label_seqs]

    fsa = k2.Fsa.from_fsas(paths)
    # conformer_lm_rescore() requires a `tokens` attribute on the FSA
    fsa.tokens = fsa.labels.clone()

    nbest = Nbest(fsa, k2.RaggedShape("[[x x] [x x x]]"))

    masked_src, src, tgt, src_key_padding_mask, weight = conformer_lm_rescore(
        nbest, model=None, bos_id=10, eos_id=20, blank_id=0
    )

    named_outputs = (
        ("masked src", masked_src),
        ("src", src),
        ("tgt", tgt),
        ("src_key_padding_mask", src_key_padding_mask),
        ("weight", weight),
    )
    for title, value in named_outputs:
        print(title, value)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@ -126,7 +160,8 @@ def main():
|
|||||||
test_make_repeat_map()
|
test_make_repeat_map()
|
||||||
test_make_hyp_to_ref_map()
|
test_make_hyp_to_ref_map()
|
||||||
test_make_repeat()
|
test_make_repeat()
|
||||||
test_compute_alignments()
|
test_compute_alignment()
|
||||||
|
test_conformer_lm_rescore()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user