mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 16:44:20 +00:00
mask the initial cache
This commit is contained in:
parent
0325e3a04e
commit
9629be124d
@ -39,6 +39,8 @@ class DecodeStream(object):
|
||||
if decoding_graph is not None:
|
||||
assert device == decoding_graph.device
|
||||
|
||||
self.params = params
|
||||
|
||||
# It contains a 2-D tensors representing the feature frames.
|
||||
self.features: torch.Tensor = None
|
||||
# how many frames are processed. (before subsampling).
|
||||
@ -49,11 +51,12 @@ class DecodeStream(object):
|
||||
# The decoding result (partial or final) of current utterance.
|
||||
self.hyp: List = []
|
||||
|
||||
self.feature_len: int = 0
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
self.hyp = [params.blank_id] * params.context_size
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
# feature_len is needed to get partial results.
|
||||
self.feature_len: int = 0
|
||||
# The rnnt_decoding_stream for fast_beam_search.
|
||||
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
|
||||
k2.RnntDecodingStream(decoding_graph)
|
||||
@ -110,7 +113,10 @@ class DecodeStream(object):
|
||||
+ ret_chunk_size,
|
||||
:,
|
||||
]
|
||||
self.num_processed_frames += chunk_size
|
||||
self.num_processed_frames += (
|
||||
chunk_size
|
||||
- self.params.right_context * self.params.subsampling_factor
|
||||
)
|
||||
|
||||
if self.num_processed_frames >= self.features.size(0):
|
||||
self._done = True
|
||||
|
@ -203,7 +203,9 @@ class Conformer(EncoderInterface):
|
||||
warmup: float = 1.0,
|
||||
chunk_size: int = 16,
|
||||
left_context: int = 64,
|
||||
right_context: int = 4,
|
||||
simulate_streaming: bool = False,
|
||||
processed_lens: Optional[Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -269,7 +271,20 @@ class Conformer(EncoderInterface):
|
||||
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
|
||||
given {states[1].shape}."""
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths + left_context)
|
||||
# src_key_padding_mask = make_pad_mask(lengths + left_context)
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
|
||||
assert processed_lens is not None
|
||||
|
||||
processed_mask = torch.arange(left_context, device=x.device).expand(
|
||||
x.size(0), left_context
|
||||
)
|
||||
processed_lens = processed_lens.view(x.size(0), 1)
|
||||
processed_mask = (processed_lens <= processed_mask).flip(1)
|
||||
|
||||
src_key_padding_mask = torch.cat(
|
||||
[processed_mask, src_key_padding_mask], dim=1
|
||||
)
|
||||
|
||||
embed = self.encoder_embed(x)
|
||||
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||
@ -282,8 +297,11 @@ class Conformer(EncoderInterface):
|
||||
warmup=warmup,
|
||||
states=states,
|
||||
left_context=left_context,
|
||||
right_context=right_context,
|
||||
) # (T, B, F)
|
||||
|
||||
if right_context > 0:
|
||||
x = x[0:-right_context, ...]
|
||||
lengths -= right_context
|
||||
else:
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
@ -465,6 +483,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
left_context: int = 0,
|
||||
right_context: int = 0,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
@ -504,7 +523,12 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
key = torch.cat([states[0], src], dim=0)
|
||||
val = key
|
||||
states[0] = key[-left_context:, ...]
|
||||
if right_context > 0:
|
||||
states[0] = key[
|
||||
-(left_context + right_context) : -right_context, ... # noqa
|
||||
]
|
||||
else:
|
||||
states[0] = key[-left_context:, ...]
|
||||
|
||||
# multi-headed self-attention module
|
||||
src_att = self.self_attn(
|
||||
@ -520,7 +544,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
conv, conv_cache = self.conv_module(src, states[1])
|
||||
conv, conv_cache = self.conv_module(src, states[1], right_context)
|
||||
states[1] = conv_cache
|
||||
|
||||
src = src + self.dropout(conv)
|
||||
@ -604,6 +628,7 @@ class ConformerEncoder(nn.Module):
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
left_context: int = 0,
|
||||
right_context: int = 0,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
@ -655,6 +680,7 @@ class ConformerEncoder(nn.Module):
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
left_context=left_context,
|
||||
right_context=right_context,
|
||||
)
|
||||
states[0][layer_index] = cache[0]
|
||||
states[1][layer_index] = cache[1]
|
||||
@ -1306,6 +1332,7 @@ class ConvolutionModule(nn.Module):
|
||||
self,
|
||||
x: Tensor,
|
||||
cache: Optional[Tensor] = None,
|
||||
right_context=0,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Compute convolution module.
|
||||
|
||||
@ -1342,7 +1369,15 @@ class ConvolutionModule(nn.Module):
|
||||
), "Cache should be None in training time"
|
||||
assert cache.size(0) == self.lorder
|
||||
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
|
||||
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
|
||||
if right_context > 0:
|
||||
cache = x.permute(2, 0, 1)[
|
||||
-(self.lorder + right_context) : ( # noqa
|
||||
-right_context
|
||||
),
|
||||
...,
|
||||
]
|
||||
else:
|
||||
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
|
||||
x = self.depthwise_conv(x)
|
||||
|
||||
x = self.deriv_balancer2(x)
|
||||
|
@ -205,6 +205,13 @@ def get_parser():
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--right-context",
|
||||
type=int,
|
||||
default=4,
|
||||
help="right context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decode-streams",
|
||||
type=int,
|
||||
@ -355,8 +362,8 @@ def decode_one_chunk(
|
||||
features.append(feat)
|
||||
feature_lens.append(feat_len)
|
||||
states.append(stream.states)
|
||||
processed_feature_lens.append(stream.feature_len)
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
processed_feature_lens.append(stream.feature_len)
|
||||
rnnt_stream_list.append(stream.rnnt_decoding_stream)
|
||||
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
@ -364,15 +371,18 @@ def decode_one_chunk(
|
||||
|
||||
# if T is less than 7 there will be an error in time reduction layer,
|
||||
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
|
||||
if features.size(1) < 7:
|
||||
feature_lens += 7 - features.size(1)
|
||||
tail_length = 7 + params.right_context * params.subsampling_factor
|
||||
if features.size(1) < tail_length:
|
||||
feature_lens += tail_length - features.size(1)
|
||||
features = torch.cat(
|
||||
[
|
||||
features,
|
||||
torch.tensor(
|
||||
LOG_EPS, dtype=features.dtype, device=device
|
||||
).expand(
|
||||
features.size(0), 7 - features.size(1), features.size(2)
|
||||
features.size(0),
|
||||
tail_length - features.size(1),
|
||||
features.size(2),
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
@ -382,6 +392,7 @@ def decode_one_chunk(
|
||||
torch.stack([x[0] for x in states], dim=2),
|
||||
torch.stack([x[1] for x in states], dim=2),
|
||||
]
|
||||
processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
|
||||
|
||||
# Note: states will be modified in streaming_forward.
|
||||
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
|
||||
@ -389,7 +400,10 @@ def decode_one_chunk(
|
||||
x_lens=feature_lens,
|
||||
states=states,
|
||||
left_context=params.left_context,
|
||||
right_context=params.right_context,
|
||||
processed_lens=processed_feature_lens,
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
@ -402,9 +416,6 @@ def decode_one_chunk(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
processed_feature_lens = torch.tensor(
|
||||
processed_feature_lens, device=device
|
||||
)
|
||||
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
|
||||
processed_lens = processed_feature_lens + encoder_out_lens
|
||||
hyp_tokens = fast_beam_search(
|
||||
@ -418,8 +429,8 @@ def decode_one_chunk(
|
||||
finished_streams = []
|
||||
for i in range(len(decode_streams)):
|
||||
decode_streams[i].states = [states[0][i], states[1][i]]
|
||||
decode_streams[i].feature_len += encoder_out_lens[i]
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decode_streams[i].feature_len += encoder_out_lens[i]
|
||||
decode_streams[i].hyp = hyp_tokens[i]
|
||||
if decode_streams[i].done:
|
||||
finished_streams.append(i)
|
||||
@ -489,7 +500,8 @@ def decode_dataset(
|
||||
samples = torch.from_numpy(audio).squeeze(0)
|
||||
|
||||
fbank = Fbank(opts)
|
||||
decode_stream.set_features(fbank(samples.to(device)))
|
||||
feature = fbank(samples.to(device))
|
||||
decode_stream.set_features(feature)
|
||||
decode_stream.ground_truth = cut.supervisions[0].text
|
||||
|
||||
decode_streams.append(decode_stream)
|
||||
@ -541,14 +553,14 @@ def decode_dataset(
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
store_transcripts(filename=recog_path, texts=sorted(results))
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
@ -602,6 +614,7 @@ def main():
|
||||
# for streaming
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
params.suffix += f"-right-context-{params.right_context}"
|
||||
|
||||
# for fast_beam_search
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
|
Loading…
x
Reference in New Issue
Block a user