Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-08 09:32:20 +00:00)
update api for RaggedTensor (#45)
* Fix code style
* update k2 version in CI
* fix compile hlg
commit 9a6e0489c8 (parent a2be2896a9)
.github/workflows/run-yesno-recipe.yml
@@ -56,7 +56,7 @@ jobs:
       run: |
         python3 -m pip install --upgrade pip black flake8
         python3 -m pip install -U pip
-        python3 -m pip install k2==1.7.dev20210908+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
+        python3 -m pip install k2==1.7.dev20210914+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
         python3 -m pip install torchaudio==0.7.2
         python3 -m pip install git+https://github.com/lhotse-speech/lhotse
.github/workflows/test.yml
@@ -32,7 +32,7 @@ jobs:
         os: [ubuntu-18.04, macos-10.15]
         python-version: [3.6, 3.7, 3.8, 3.9]
         torch: ["1.8.1"]
-        k2-version: ["1.7.dev20210908"]
+        k2-version: ["1.7.dev20210914"]

       fail-fast: false
@@ -103,7 +103,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     LG.labels[LG.labels >= first_token_disambig_id] = 0

     assert isinstance(LG.aux_labels, k2.RaggedTensor)
-    LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
+    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

     LG = k2.remove_epsilon(LG)
     logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
@@ -81,7 +81,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     LG.labels[LG.labels >= first_token_disambig_id] = 0

     assert isinstance(LG.aux_labels, k2.RaggedTensor)
-    LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
+    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

     LG = k2.remove_epsilon(LG)
     logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")
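The two compile_HLG hunks above track the k2 rename in which a RaggedTensor's flat storage is exposed as .values instead of .data. A minimal sketch of the new accessor, with a toy tensor and an invented disambig-ID threshold standing in for the real LG.aux_labels (assumes a k2 build at least as new as the nightly pinned above):

import k2

# Toy stand-in for LG.aux_labels; the real one holds one sublist of
# word IDs per arc.
aux_labels = k2.RaggedTensor("[ [1 2] [3 4] ]")

# .values (formerly .data) is the flat 1-D torch.Tensor backing the
# ragged tensor, so masked assignment mutates it in place, exactly
# as the hunks above do.
first_word_disambig_id = 3  # invented threshold for illustration
aux_labels.values[aux_labels.values >= first_word_disambig_id] = 0

print(aux_labels)  # now holds [[1, 2], [0, 0]]

Because .values returns the backing tensor rather than a copy, compile_HLG can zero out the disambiguation symbols without rebuilding LG.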
@@ -221,7 +221,8 @@ def nbest_decoding(
     if isinstance(lattice.aux_labels, torch.Tensor):
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
-        word_seq = lattice.aux_labels.index(path, remove_axis=True)
+        word_seq = lattice.aux_labels.index(path)
+        word_seq = word_seq.remove_axis(1)

     # Remove 0 (epsilon) and -1 from word_seq
     word_seq = word_seq.remove_values_leq(0)
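Besides the rename, RaggedTensor.index no longer takes a remove_axis argument: indexing with a ragged tensor of row indexes always yields one extra axis, which the caller now removes in a separate, explicit step. A small sketch with invented stand-ins (aux plays lattice.aux_labels, path plays the n-best paths of nbest_decoding):

import k2

# One sublist of word IDs per arc.
aux = k2.RaggedTensor("[ [10 11] [12] [13 14] ]")
# One sublist of arc indexes per path.
path = k2.RaggedTensor("[ [0 2] [1] ]")

# New API: index() returns a 3-axis tensor [path][arc][word] ...
word_seq = aux.index(path)
# ... and the middle axis is removed explicitly.
word_seq = word_seq.remove_axis(1)

print(word_seq)  # contains [[10, 11, 13, 14], [12]]

The same two-step pattern recurs in every decode.py hunk below that previously passed remove_axis=True.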
@@ -300,7 +301,7 @@ def nbest_decoding(
     # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
     # aux_labels is also a k2.RaggedTensor with 2 axes
     aux_labels, _ = lattice.aux_labels.index(
-        indexes=best_path.data, axis=0, need_value_indexes=False
+        indexes=best_path.values, axis=0, need_value_indexes=False
     )

     best_path_fsa = k2.linear_fsa(labels)
@@ -430,7 +431,8 @@ def rescore_with_n_best_list(
     if isinstance(lattice.aux_labels, torch.Tensor):
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
-        word_seq = lattice.aux_labels.index(path, remove_axis=True)
+        word_seq = lattice.aux_labels.index(path)
+        word_seq = word_seq.remove_axis(1)

     # Remove epsilons and -1 from word_seq
     word_seq = word_seq.remove_values_leq(0)
@@ -520,7 +522,7 @@ def rescore_with_n_best_list(
     # aux_labels is also a k2.RaggedTensor with 2 axes

     aux_labels, _ = lattice.aux_labels.index(
-        indexes=best_path.data, axis=0, need_value_indexes=False
+        indexes=best_path.values, axis=0, need_value_indexes=False
     )

     best_path_fsa = k2.linear_fsa(labels)
@@ -666,7 +668,8 @@ def nbest_oracle(
     if isinstance(lattice.aux_labels, torch.Tensor):
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
-        word_seq = lattice.aux_labels.index(path, remove_axis=True)
+        word_seq = lattice.aux_labels.index(path)
+        word_seq = word_seq.remove_axis(1)

     word_seq = word_seq.remove_values_leq(0)
     unique_word_seq, _, _ = word_seq.unique(
@@ -757,7 +760,8 @@ def rescore_with_attention_decoder(
     if isinstance(lattice.aux_labels, torch.Tensor):
         word_seq = k2.ragged.index(lattice.aux_labels, path)
     else:
-        word_seq = lattice.aux_labels.index(path, remove_axis=True)
+        word_seq = lattice.aux_labels.index(path)
+        word_seq = word_seq.remove_axis(1)

     # Remove epsilons and -1 from word_seq
     word_seq = word_seq.remove_values_leq(0)
@@ -810,7 +814,8 @@ def rescore_with_attention_decoder(
     if isinstance(lattice.tokens, torch.Tensor):
         token_seq = k2.ragged.index(lattice.tokens, path)
     else:
-        token_seq = lattice.tokens.index(path, remove_axis=True)
+        token_seq = lattice.tokens.index(path)
+        token_seq = token_seq.remove_axis(1)

     # Remove epsilons and -1 from token_seq
     token_seq = token_seq.remove_values_leq(0)
@@ -890,10 +895,12 @@ def rescore_with_attention_decoder(
     labels = labels.remove_values_eq(-1)

     if isinstance(lattice.aux_labels, torch.Tensor):
-        aux_labels = k2.index_select(lattice.aux_labels, best_path.data)
+        aux_labels = k2.index_select(
+            lattice.aux_labels, best_path.values
+        )
     else:
         aux_labels, _ = lattice.aux_labels.index(
-            indexes=best_path.data, axis=0, need_value_indexes=False
+            indexes=best_path.values, axis=0, need_value_indexes=False
         )

     best_path_fsa = k2.linear_fsa(labels)
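In the @@ -890 hunk the tensor branch goes through k2.index_select, and best_path.values (formerly best_path.data) supplies the flat arc indexes. A toy sketch with invented values (attr plays a tensor-valued lattice.aux_labels, best_path plays the ragged best-path arc indexes):

import k2
import torch

# One attribute entry per arc.
attr = torch.tensor([100, 101, 102, 103], dtype=torch.int32)
# Ragged arc indexes, one sublist per sequence.
best_path = k2.RaggedTensor("[ [0 3] [2] ]")

# .values flattens the ragged indexes into the 1-D int32 tensor
# that k2.index_select expects.
aux_labels = k2.index_select(attr, best_path.values)

print(aux_labels)  # values [100, 103, 102]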
@@ -207,7 +207,7 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
         # remove the states and arcs axes.
         aux_shape = aux_shape.remove_axis(1)
         aux_shape = aux_shape.remove_axis(1)
-        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data)
+        aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
     else:
         # remove axis corresponding to states.
         aux_shape = best_paths.arcs.shape().remove_axis(1)
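The get_texts hunk re-pairs a reduced ragged shape with the unchanged flat values; again only the accessor name changes. A toy round trip of that shape/values split, using a 3-axis stand-in and removing one axis where the real code removes two (states and arcs):

import k2

# 3-axis stand-in for the composed aux_labels inside get_texts.
aux_labels = k2.RaggedTensor("[ [ [1 2] [3] ] [ [4] [5 6] ] ]")

# Drop an interior axis from the shape, then re-pair it with the
# flat values (formerly .data) to build a 2-axis ragged tensor.
aux_shape = aux_labels.shape.remove_axis(1)
flat = k2.RaggedTensor(aux_shape, aux_labels.values)

print(flat)  # contains [[1, 2, 3], [4, 5, 6]]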