Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

commit 6a937ffa40 (parent 88ed814197)

    Minor fixes for streaming decoding
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from typing import List, Optional, Tuple
 
 import k2
@@ -45,6 +46,7 @@ class DecodeStream(object):
             assert device == decoding_graph.device
 
         self.params = params
+        self.LOG_EPS = math.log(1e-10)
 
         self.states = initial_states
 
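LOG_EPS introduced above is reused further down as the value for tail padding of the log-mel features; padding with log(1e-10) presumably keeps the padded frames close to silence in the log domain rather than inserting zeros. A minimal illustration (plain Python, not taken from the repository):

    import math

    # Value used to pad log-mel feature frames: the log of a tiny energy,
    # i.e. effectively silence in the log-mel domain.
    LOG_EPS = math.log(1e-10)
    print(LOG_EPS)  # roughly -23.03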
@@ -52,6 +54,7 @@ class DecodeStream(object):
         self.features: torch.Tensor = None
         # how many frames have been processed. (before subsampling).
         # we only modify this value in `func:get_feature_frames`.
+        self.num_frames = 0
         self.num_processed_frames: int = 0
         self._done: bool = False
         # The transcript of current utterance.
@@ -62,7 +65,11 @@ class DecodeStream(object):
         # how many frames have been processed, after subsampling (i.e. a
         # cumulative sum of the second return value of
         # encoder.streaming_forward
-        self.feature_len: int = 0
+        self.done_frames: int = 0
 
+        self.pad_length = (
+            params.right_context + 2
+        ) * params.subsampling_factor + 3
+
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
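The new pad_length folds together the quantities the callers previously added themselves (see the decode_one_chunk hunks below): params.right_context plus 2 subsampled frames (one frame is cut off on each side of the encoder_embed output, per the comment kept later in this diff), converted to input frames via params.subsampling_factor, plus 3 extra input frames because the subsampling computes output lengths as ((x_len - 1) // 2 - 1) // 2. A small worked example with illustrative values (right_context and subsampling_factor come from the decoding configuration; the numbers here are only for the arithmetic):

    # Illustrative values only; both come from the decoding configuration.
    right_context = 4
    subsampling_factor = 4

    # Same formula as in the diff above.
    pad_length = (right_context + 2) * subsampling_factor + 3
    print(pad_length)  # (4 + 2) * 4 + 3 = 27 input frames of tail padding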
@@ -86,27 +93,32 @@ class DecodeStream(object):
         features: torch.Tensor,
     ) -> None:
         """Set features tensor of current utterance."""
-        self.features = features
+        assert features.dim() == 2, features.dim()
+        self.num_frames = features.size(0)
+        # tail padding
+        self.features = torch.nn.functional.pad(
+            features,
+            (0, 0, 0, self.pad_length),
+            mode="constant",
+            value=self.LOG_EPS,
+        )
 
     def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]:
         """Consume chunk_size frames of features"""
         # plus 3 here because we subsampling features with
         # lengths = ((x_lens - 1) // 2 - 1) // 2
-        ret_chunk_size = min(
-            self.features.size(0) - self.num_processed_frames, chunk_size + 3
+        update_length = min(
+            self.num_frames - self.num_processed_frames, chunk_size
         )
+        ret_length = update_length + self.pad_length
+
         ret_features = self.features[
             self.num_processed_frames : self.num_processed_frames  # noqa
-            + ret_chunk_size,
-            :,
+            + ret_length
         ]
-        self.num_processed_frames += (
-            chunk_size
-            - 2 * self.params.subsampling_factor
-            - self.params.right_context * self.params.subsampling_factor
-        )
 
-        if self.num_processed_frames >= self.features.size(0):
+        self.num_processed_frames += update_length
+        if self.num_processed_frames >= self.num_frames:
             self._done = True
 
-        return ret_features, ret_chunk_size
+        return ret_features, ret_length
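With these changes, set_features() pads the utterance once at the end with pad_length frames of LOG_EPS, and each get_feature_frames() call returns update_length + pad_length frames while advancing num_processed_frames by only update_length, so consecutive chunks overlap in the padded look-ahead region. A simplified, self-contained sketch of that bookkeeping (plain Python lists instead of tensors; TinyStream and its fields are illustrative stand-ins, not the class in the repository):

    from typing import List, Tuple

    LOG_EPS = -23.025850929940457  # math.log(1e-10)


    class TinyStream:
        """Illustrative stand-in for DecodeStream's chunking logic."""

        def __init__(self, pad_length: int) -> None:
            self.pad_length = pad_length
            self.features: List[float] = []
            self.num_frames = 0
            self.num_processed_frames = 0
            self.done = False

        def set_features(self, features: List[float]) -> None:
            self.num_frames = len(features)
            # tail padding, mirroring torch.nn.functional.pad(..., value=LOG_EPS)
            self.features = features + [LOG_EPS] * self.pad_length

        def get_feature_frames(self, chunk_size: int) -> Tuple[List[float], int]:
            update_length = min(self.num_frames - self.num_processed_frames, chunk_size)
            ret_length = update_length + self.pad_length
            start = self.num_processed_frames
            ret = self.features[start : start + ret_length]
            self.num_processed_frames += update_length
            if self.num_processed_frames >= self.num_frames:
                self.done = True
            return ret, ret_length


    stream = TinyStream(pad_length=27)
    stream.set_features([0.0] * 100)  # an utterance with 100 feature frames
    while not stream.done:
        chunk, ret_length = stream.get_feature_frames(chunk_size=64)
        print(len(chunk), ret_length)  # prints "91 91", then "63 63"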
@@ -341,20 +341,16 @@ def decode_one_chunk(
     states = []
 
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
 
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
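On the caller side, each decoding step now requests exactly params.decode_chunk_size * params.subsampling_factor input frames; the extra right-context and look-ahead frames are supplied once by the tail padding inside DecodeStream instead of being re-fetched for every chunk. The difference, with illustrative configuration values:

    # Illustrative values; decode_chunk_size is counted at the subsampled
    # (encoder output) frame rate.
    decode_chunk_size = 16
    right_context = 4
    subsampling_factor = 4

    # Old request per chunk (expression removed in this diff):
    old_request = (decode_chunk_size + 2 + right_context) * subsampling_factor
    # New request per chunk:
    new_request = decode_chunk_size * subsampling_factor
    print(old_request, new_request)  # 88 64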
@@ -388,16 +384,15 @@ def decode_one_chunk(
         torch.stack([x[1] for x in states], dim=2),
     ]
 
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
 
-    # Note: states will be modified in streaming_forward.
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
         x_lens=feature_lens,
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
 
     if params.decoding_method == "greedy_search":
@@ -411,7 +406,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -423,7 +418,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -469,7 +464,7 @@ def decode_dataset(
     opts.frame_opts.samp_freq = 16000
     opts.mel_opts.num_bins = 80
 
-    log_interval = 50
+    log_interval = 100
 
     decode_results = []
     # Contain decode streams currently running.
@@ -557,6 +552,9 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
+        # sort results so we can easily compare the difference between two
+        # recognition results
+        results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
The hunks below apply the corresponding changes to a second copy of decode_one_chunk() and save_results() touched by this commit.

@@ -347,20 +347,16 @@ def decode_one_chunk(
     states = []
 
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
 
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
@@ -369,6 +365,9 @@ def decode_one_chunk(
 
     # if T is less than 7 there will be an error in time reduction layer,
     # because we subsample features with ((x_len - 1) // 2 - 1) // 2
+    # we plus 2 here because we will cut off one frame on each size of
+    # encoder_embed output as they see invalid paddings. so we need extra 2
+    # frames.
     tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
         feature_lens += tail_length - features.size(1)
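The floor of 7 in tail_length comes from the two-stage subsampling mentioned in the comment: out_len = ((x_len - 1) // 2 - 1) // 2 first becomes positive at x_len = 7, so shorter inputs would yield zero frames in the time-reduction layer. The remaining (2 + right_context) * subsampling_factor frames cover the frame trimmed on each side of the encoder_embed output plus the attention right context, mirroring the pad_length added in DecodeStream. A quick check of the formula:

    # out_len = ((x_len - 1) // 2 - 1) // 2 reaches 1 for the first time at x_len = 7.
    for x_len in range(1, 12):
        out_len = ((x_len - 1) // 2 - 1) // 2
        print(x_len, out_len)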
@@ -390,7 +389,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
 
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
@@ -398,7 +397,7 @@ def decode_one_chunk(
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -414,7 +413,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -426,7 +425,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -561,7 +560,10 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        # sort results so we can easily compare the difference between two
+        # recognition results
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
Corresponding changes in a third copy of decode_one_chunk() and save_results():

@@ -348,20 +348,16 @@ def decode_one_chunk(
     states = []
 
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
 
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
        )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
@@ -394,7 +390,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
 
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
@@ -402,7 +398,7 @@ def decode_one_chunk(
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
     )
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -418,7 +414,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -430,7 +426,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -565,7 +561,8 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
Corresponding changes in a fourth copy of decode_one_chunk() and save_results():

@@ -359,20 +359,16 @@ def decode_one_chunk(
     states = []
 
     rnnt_stream_list = []
-    processed_feature_lens = []
+    processed_lens = []
 
     for stream in decode_streams:
-        # we plus 2 here because we will cut off one frame on each size of
-        # encoder_embed output as they see invalid paddings. so we need extra 2
-        # frames.
         feat, feat_len = stream.get_feature_frames(
-            (params.decode_chunk_size + 2 + params.right_context)
-            * params.subsampling_factor
+            params.decode_chunk_size * params.subsampling_factor
        )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
-        processed_feature_lens.append(stream.feature_len)
+        processed_lens.append(stream.done_frames)
         if params.decoding_method == "fast_beam_search":
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
@@ -405,7 +401,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
-    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
+    processed_lens = torch.tensor(processed_lens, device=device)
 
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
@@ -413,7 +409,7 @@ def decode_one_chunk(
         states=states,
         left_context=params.left_context,
         right_context=params.right_context,
-        processed_lens=processed_feature_lens,
+        processed_lens=processed_lens,
    )
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -429,7 +425,7 @@ def decode_one_chunk(
             max_states=params.max_states,
         )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
-        processed_lens = processed_feature_lens + encoder_out_lens
+        processed_lens = processed_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
             model, encoder_out, processed_lens, decoding_streams
         )
@@ -441,7 +437,7 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
-        decode_streams[i].feature_len += encoder_out_lens[i]
+        decode_streams[i].done_frames += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
@@ -576,7 +572,8 @@ def save_results(
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=sorted(results))
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
         logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned