mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
update api for RaggedTensor (#45)
* Fix code style * update k2 version in CI * fix compile hlg
This commit is contained in:
parent
a2be2896a9
commit
9a6e0489c8
2
.github/workflows/run-yesno-recipe.yml
vendored
2
.github/workflows/run-yesno-recipe.yml
vendored
@ -56,7 +56,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip black flake8
|
python3 -m pip install --upgrade pip black flake8
|
||||||
python3 -m pip install -U pip
|
python3 -m pip install -U pip
|
||||||
python3 -m pip install k2==1.7.dev20210908+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
|
python3 -m pip install k2==1.7.dev20210914+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
|
||||||
python3 -m pip install torchaudio==0.7.2
|
python3 -m pip install torchaudio==0.7.2
|
||||||
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
|
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
|
||||||
|
|
||||||
|
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@ -32,7 +32,7 @@ jobs:
|
|||||||
os: [ubuntu-18.04, macos-10.15]
|
os: [ubuntu-18.04, macos-10.15]
|
||||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||||
torch: ["1.8.1"]
|
torch: ["1.8.1"]
|
||||||
k2-version: ["1.7.dev20210908"]
|
k2-version: ["1.7.dev20210914"]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
|
@ -103,7 +103,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
|
|||||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||||
|
|
||||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||||
LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
|
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||||
|
|
||||||
LG = k2.remove_epsilon(LG)
|
LG = k2.remove_epsilon(LG)
|
||||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||||
|
@ -81,7 +81,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
|
|||||||
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
LG.labels[LG.labels >= first_token_disambig_id] = 0
|
||||||
|
|
||||||
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
assert isinstance(LG.aux_labels, k2.RaggedTensor)
|
||||||
LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
|
LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
|
||||||
|
|
||||||
LG = k2.remove_epsilon(LG)
|
LG = k2.remove_epsilon(LG)
|
||||||
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
|
||||||
|
@ -221,7 +221,8 @@ def nbest_decoding(
|
|||||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||||
else:
|
else:
|
||||||
word_seq = lattice.aux_labels.index(path, remove_axis=True)
|
word_seq = lattice.aux_labels.index(path)
|
||||||
|
word_seq = word_seq.remove_axis(1)
|
||||||
|
|
||||||
# Remove 0 (epsilon) and -1 from word_seq
|
# Remove 0 (epsilon) and -1 from word_seq
|
||||||
word_seq = word_seq.remove_values_leq(0)
|
word_seq = word_seq.remove_values_leq(0)
|
||||||
@ -300,7 +301,7 @@ def nbest_decoding(
|
|||||||
# lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
|
# lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
|
||||||
# aux_labels is also a k2.RaggedTensor with 2 axes
|
# aux_labels is also a k2.RaggedTensor with 2 axes
|
||||||
aux_labels, _ = lattice.aux_labels.index(
|
aux_labels, _ = lattice.aux_labels.index(
|
||||||
indexes=best_path.data, axis=0, need_value_indexes=False
|
indexes=best_path.values, axis=0, need_value_indexes=False
|
||||||
)
|
)
|
||||||
|
|
||||||
best_path_fsa = k2.linear_fsa(labels)
|
best_path_fsa = k2.linear_fsa(labels)
|
||||||
@ -430,7 +431,8 @@ def rescore_with_n_best_list(
|
|||||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||||
else:
|
else:
|
||||||
word_seq = lattice.aux_labels.index(path, remove_axis=True)
|
word_seq = lattice.aux_labels.index(path)
|
||||||
|
word_seq = word_seq.remove_axis(1)
|
||||||
|
|
||||||
# Remove epsilons and -1 from word_seq
|
# Remove epsilons and -1 from word_seq
|
||||||
word_seq = word_seq.remove_values_leq(0)
|
word_seq = word_seq.remove_values_leq(0)
|
||||||
@ -520,7 +522,7 @@ def rescore_with_n_best_list(
|
|||||||
# aux_labels is also a k2.RaggedTensor with 2 axes
|
# aux_labels is also a k2.RaggedTensor with 2 axes
|
||||||
|
|
||||||
aux_labels, _ = lattice.aux_labels.index(
|
aux_labels, _ = lattice.aux_labels.index(
|
||||||
indexes=best_path.data, axis=0, need_value_indexes=False
|
indexes=best_path.values, axis=0, need_value_indexes=False
|
||||||
)
|
)
|
||||||
|
|
||||||
best_path_fsa = k2.linear_fsa(labels)
|
best_path_fsa = k2.linear_fsa(labels)
|
||||||
@ -666,7 +668,8 @@ def nbest_oracle(
|
|||||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||||
else:
|
else:
|
||||||
word_seq = lattice.aux_labels.index(path, remove_axis=True)
|
word_seq = lattice.aux_labels.index(path)
|
||||||
|
word_seq = word_seq.remove_axis(1)
|
||||||
|
|
||||||
word_seq = word_seq.remove_values_leq(0)
|
word_seq = word_seq.remove_values_leq(0)
|
||||||
unique_word_seq, _, _ = word_seq.unique(
|
unique_word_seq, _, _ = word_seq.unique(
|
||||||
@ -757,7 +760,8 @@ def rescore_with_attention_decoder(
|
|||||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||||
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
word_seq = k2.ragged.index(lattice.aux_labels, path)
|
||||||
else:
|
else:
|
||||||
word_seq = lattice.aux_labels.index(path, remove_axis=True)
|
word_seq = lattice.aux_labels.index(path)
|
||||||
|
word_seq = word_seq.remove_axis(1)
|
||||||
|
|
||||||
# Remove epsilons and -1 from word_seq
|
# Remove epsilons and -1 from word_seq
|
||||||
word_seq = word_seq.remove_values_leq(0)
|
word_seq = word_seq.remove_values_leq(0)
|
||||||
@ -810,7 +814,8 @@ def rescore_with_attention_decoder(
|
|||||||
if isinstance(lattice.tokens, torch.Tensor):
|
if isinstance(lattice.tokens, torch.Tensor):
|
||||||
token_seq = k2.ragged.index(lattice.tokens, path)
|
token_seq = k2.ragged.index(lattice.tokens, path)
|
||||||
else:
|
else:
|
||||||
token_seq = lattice.tokens.index(path, remove_axis=True)
|
token_seq = lattice.tokens.index(path)
|
||||||
|
token_seq = token_seq.remove_axis(1)
|
||||||
|
|
||||||
# Remove epsilons and -1 from token_seq
|
# Remove epsilons and -1 from token_seq
|
||||||
token_seq = token_seq.remove_values_leq(0)
|
token_seq = token_seq.remove_values_leq(0)
|
||||||
@ -890,10 +895,12 @@ def rescore_with_attention_decoder(
|
|||||||
labels = labels.remove_values_eq(-1)
|
labels = labels.remove_values_eq(-1)
|
||||||
|
|
||||||
if isinstance(lattice.aux_labels, torch.Tensor):
|
if isinstance(lattice.aux_labels, torch.Tensor):
|
||||||
aux_labels = k2.index_select(lattice.aux_labels, best_path.data)
|
aux_labels = k2.index_select(
|
||||||
|
lattice.aux_labels, best_path.values
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
aux_labels, _ = lattice.aux_labels.index(
|
aux_labels, _ = lattice.aux_labels.index(
|
||||||
indexes=best_path.data, axis=0, need_value_indexes=False
|
indexes=best_path.values, axis=0, need_value_indexes=False
|
||||||
)
|
)
|
||||||
|
|
||||||
best_path_fsa = k2.linear_fsa(labels)
|
best_path_fsa = k2.linear_fsa(labels)
|
||||||
|
@ -207,7 +207,7 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
|
|||||||
# remove the states and arcs axes.
|
# remove the states and arcs axes.
|
||||||
aux_shape = aux_shape.remove_axis(1)
|
aux_shape = aux_shape.remove_axis(1)
|
||||||
aux_shape = aux_shape.remove_axis(1)
|
aux_shape = aux_shape.remove_axis(1)
|
||||||
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data)
|
aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
|
||||||
else:
|
else:
|
||||||
# remove axis corresponding to states.
|
# remove axis corresponding to states.
|
||||||
aux_shape = best_paths.arcs.shape().remove_axis(1)
|
aux_shape = best_paths.arcs.shape().remove_axis(1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user