Update API for RaggedTensor (#45)

* Fix code style

* Update k2 version in CI

* Fix compile_hlg
Wei Kang, 2021-09-14 16:39:56 +08:00 (committed by GitHub)
Parent: a2be2896a9
Commit: 9a6e0489c8
6 changed files with 21 additions and 14 deletions


@@ -56,7 +56,7 @@ jobs:
 run: |
   python3 -m pip install --upgrade pip black flake8
   python3 -m pip install -U pip
-  python3 -m pip install k2==1.7.dev20210908+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
+  python3 -m pip install k2==1.7.dev20210914+cpu.torch1.7.1 -f https://k2-fsa.org/nightly/
   python3 -m pip install torchaudio==0.7.2
   python3 -m pip install git+https://github.com/lhotse-speech/lhotse

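The pin above tracks a nightly k2 wheel. A minimal runtime sanity check (not part of this commit), using only the standard library, to confirm that the installed wheel matches the version pinned above:

    import importlib.metadata  # Python >= 3.8

    # Verify the installed k2 wheel matches the version pinned in CI.
    expected = "1.7.dev20210914+cpu.torch1.7.1"
    installed = importlib.metadata.version("k2")
    assert installed == expected, f"expected k2 {expected}, got {installed}"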

@@ -32,7 +32,7 @@ jobs:
 os: [ubuntu-18.04, macos-10.15]
 python-version: [3.6, 3.7, 3.8, 3.9]
 torch: ["1.8.1"]
-k2-version: ["1.7.dev20210908"]
+k2-version: ["1.7.dev20210914"]
 fail-fast: false


@@ -103,7 +103,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
 LG.labels[LG.labels >= first_token_disambig_id] = 0
 assert isinstance(LG.aux_labels, k2.RaggedTensor)
-LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
+LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
 LG = k2.remove_epsilon(LG)
 logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

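Both compile_HLG hunks (this one and the identical one below) switch from the removed RaggedTensor.data attribute to RaggedTensor.values, which returns the flat torch.Tensor backing the ragged tensor. A minimal sketch of the rename on a toy tensor (values and threshold are made up for illustration):

    import k2

    ragged = k2.RaggedTensor([[1, 2], [3]])  # toy stand-in for LG.aux_labels

    # Old API: ragged.data; new API: ragged.values.
    flat = ragged.values   # tensor([1, 2, 3], dtype=torch.int32)
    flat[flat >= 3] = 0    # in-place masking, as done above for disambig ids
    # ragged is now [[1, 2], [0]]: .values exposes the underlying storage.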

@@ -81,7 +81,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
 LG.labels[LG.labels >= first_token_disambig_id] = 0
 assert isinstance(LG.aux_labels, k2.RaggedTensor)
-LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0
+LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0
 LG = k2.remove_epsilon(LG)
 logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")


@@ -221,7 +221,8 @@ def nbest_decoding(
 if isinstance(lattice.aux_labels, torch.Tensor):
     word_seq = k2.ragged.index(lattice.aux_labels, path)
 else:
-    word_seq = lattice.aux_labels.index(path, remove_axis=True)
+    word_seq = lattice.aux_labels.index(path)
+    word_seq = word_seq.remove_axis(1)
 # Remove 0 (epsilon) and -1 from word_seq
 word_seq = word_seq.remove_values_leq(0)
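
RaggedTensor.index no longer takes a remove_axis argument; indexing with a ragged tensor of indexes adds an axis, and dropping it is now an explicit remove_axis(1) call. A sketch with toy 2-axis tensors (shapes chosen for illustration only):

    import k2

    aux_labels = k2.RaggedTensor([[10, 20], [30]])  # stand-in for lattice.aux_labels
    path = k2.RaggedTensor([[0, 1], [1]])           # stand-in for the n-best paths

    # Old API: aux_labels.index(path, remove_axis=True)
    word_seq = aux_labels.index(path)    # 3 axes: one sublist per index
    word_seq = word_seq.remove_axis(1)   # back to 2 axes: [[10, 20, 30], [30]]
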
@@ -300,7 +301,7 @@ def nbest_decoding(
 # lattice.aux_labels is a k2.RaggedTensor with 2 axes, so
 # aux_labels is also a k2.RaggedTensor with 2 axes
 aux_labels, _ = lattice.aux_labels.index(
-    indexes=best_path.data, axis=0, need_value_indexes=False
+    indexes=best_path.values, axis=0, need_value_indexes=False
 )
 best_path_fsa = k2.linear_fsa(labels)
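
When the indexes are a plain torch.Tensor rather than a ragged tensor, RaggedTensor.index selects sublists along the given axis and returns a pair; the second element maps result values back to the source and is skipped here via need_value_indexes=False. A sketch under the same assumption of toy data:

    import k2
    import torch

    aux_labels = k2.RaggedTensor([[10, 20], [30], [40]])
    idx = torch.tensor([2, 0], dtype=torch.int32)  # k2 expects int32 indexes

    picked, _ = aux_labels.index(indexes=idx, axis=0, need_value_indexes=False)
    # picked == [[40], [10, 20]]
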
@@ -430,7 +431,8 @@ def rescore_with_n_best_list(
 if isinstance(lattice.aux_labels, torch.Tensor):
     word_seq = k2.ragged.index(lattice.aux_labels, path)
 else:
-    word_seq = lattice.aux_labels.index(path, remove_axis=True)
+    word_seq = lattice.aux_labels.index(path)
+    word_seq = word_seq.remove_axis(1)
 # Remove epsilons and -1 from word_seq
 word_seq = word_seq.remove_values_leq(0)
@@ -520,7 +522,7 @@ def rescore_with_n_best_list(
 # aux_labels is also a k2.RaggedTensor with 2 axes
 aux_labels, _ = lattice.aux_labels.index(
-    indexes=best_path.data, axis=0, need_value_indexes=False
+    indexes=best_path.values, axis=0, need_value_indexes=False
 )
 best_path_fsa = k2.linear_fsa(labels)
@@ -666,7 +668,8 @@ def nbest_oracle(
 if isinstance(lattice.aux_labels, torch.Tensor):
     word_seq = k2.ragged.index(lattice.aux_labels, path)
 else:
-    word_seq = lattice.aux_labels.index(path, remove_axis=True)
+    word_seq = lattice.aux_labels.index(path)
+    word_seq = word_seq.remove_axis(1)
 word_seq = word_seq.remove_values_leq(0)
 unique_word_seq, _, _ = word_seq.unique(
@@ -757,7 +760,8 @@ def rescore_with_attention_decoder(
 if isinstance(lattice.aux_labels, torch.Tensor):
     word_seq = k2.ragged.index(lattice.aux_labels, path)
 else:
-    word_seq = lattice.aux_labels.index(path, remove_axis=True)
+    word_seq = lattice.aux_labels.index(path)
+    word_seq = word_seq.remove_axis(1)
 # Remove epsilons and -1 from word_seq
 word_seq = word_seq.remove_values_leq(0)
@@ -810,7 +814,8 @@ def rescore_with_attention_decoder(
 if isinstance(lattice.tokens, torch.Tensor):
     token_seq = k2.ragged.index(lattice.tokens, path)
 else:
-    token_seq = lattice.tokens.index(path, remove_axis=True)
+    token_seq = lattice.tokens.index(path)
+    token_seq = token_seq.remove_axis(1)
 # Remove epsilons and -1 from token_seq
 token_seq = token_seq.remove_values_leq(0)
@@ -890,10 +895,12 @@ def rescore_with_attention_decoder(
 labels = labels.remove_values_eq(-1)
 if isinstance(lattice.aux_labels, torch.Tensor):
-    aux_labels = k2.index_select(lattice.aux_labels, best_path.data)
+    aux_labels = k2.index_select(
+        lattice.aux_labels, best_path.values
+    )
 else:
     aux_labels, _ = lattice.aux_labels.index(
-        indexes=best_path.data, axis=0, need_value_indexes=False
+        indexes=best_path.values, axis=0, need_value_indexes=False
     )
 best_path_fsa = k2.linear_fsa(labels)
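
In the tensor branch, only the accessor changes: best_path.values now supplies the flat arc indexes to k2.index_select, which gathers like torch.index_select but (in k2) also tolerates -1 entries by producing 0. A sketch with made-up tensors:

    import k2
    import torch

    aux_labels = torch.tensor([5, 6, 7], dtype=torch.int32)
    arc_ids = torch.tensor([2, -1, 0], dtype=torch.int32)

    picked = k2.index_select(aux_labels, arc_ids)
    # picked == tensor([7, 0, 5], dtype=torch.int32); -1 maps to 0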


@@ -207,7 +207,7 @@ def get_texts(best_paths: k2.Fsa) -> List[List[int]]:
     # remove the states and arcs axes.
     aux_shape = aux_shape.remove_axis(1)
     aux_shape = aux_shape.remove_axis(1)
-    aux_labels = k2.RaggedTensor(aux_shape, aux_labels.data)
+    aux_labels = k2.RaggedTensor(aux_shape, aux_labels.values)
 else:
     # remove axis corresponding to states.
     aux_shape = best_paths.arcs.shape().remove_axis(1)
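
get_texts flattens the per-arc ragged aux_labels by stripping axes from the shape and rebuilding the tensor; the rebuild now reads the flat values via .values. A sketch of the shape-plus-values constructor on a toy 3-axis tensor (axes chosen only for illustration):

    import k2

    labels = k2.RaggedTensor([[[1, 2]], [[3]]])  # toy 3-axis ragged tensor

    shape = labels.shape.remove_axis(1)              # merge out the middle axis
    rebuilt = k2.RaggedTensor(shape, labels.values)  # [[1, 2], [3]]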