mask the initial cache

pkufool 2022-06-01 21:50:02 +08:00
parent 0325e3a04e
commit 9629be124d
3 changed files with 72 additions and 18 deletions

File 1 of 3 (DecodeStream)

@@ -39,6 +39,8 @@ class DecodeStream(object):
if decoding_graph is not None:
assert device == decoding_graph.device
self.params = params
# It contains a 2-D tensor representing the feature frames.
self.features: torch.Tensor = None
# How many frames have been processed (before subsampling).
@@ -49,11 +51,12 @@ class DecodeStream(object):
# The decoding result (partial or final) of current utterance.
self.hyp: List = []
self.feature_len: int = 0
if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size
elif params.decoding_method == "fast_beam_search":
# feature_len is needed to get partial results.
self.feature_len: int = 0
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
k2.RnntDecodingStream(decoding_graph)
@@ -110,7 +113,10 @@ class DecodeStream(object):
+ ret_chunk_size,
:,
]
self.num_processed_frames += chunk_size
self.num_processed_frames += (
chunk_size
- self.params.right_context * self.params.subsampling_factor
)
if self.num_processed_frames >= self.features.size(0):
self._done = True
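
With a lookahead, consecutive chunks overlap: the last right_context * subsampling_factor input frames of one chunk are fed again at the start of the next, so the counter advances by less than the chunk itself. A minimal sketch of the arithmetic, with illustrative values:

    # Illustrative numbers only; the real values come from params.
    chunk_size = 64               # input frames consumed per forward call
    right_context = 4             # lookahead, counted after subsampling
    subsampling_factor = 4

    # The trailing lookahead frames are re-read by the next chunk,
    # so the stream only advances by:
    advance = chunk_size - right_context * subsampling_factor
    print(advance)  # 48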

File 2 of 3 (Conformer encoder)

@@ -203,7 +203,9 @@ class Conformer(EncoderInterface):
warmup: float = 1.0,
chunk_size: int = 16,
left_context: int = 64,
right_context: int = 4,
simulate_streaming: bool = False,
processed_lens: Optional[Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -269,7 +271,20 @@
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}."""
src_key_padding_mask = make_pad_mask(lengths + left_context)
# src_key_padding_mask = make_pad_mask(lengths + left_context)
src_key_padding_mask = make_pad_mask(lengths)
assert processed_lens is not None
processed_mask = torch.arange(left_context, device=x.device).expand(
x.size(0), left_context
)
processed_lens = processed_lens.view(x.size(0), 1)
processed_mask = (processed_lens <= processed_mask).flip(1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
embed = self.encoder_embed(x)
embed, pos_enc = self.encoder_pos(embed, left_context)
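
At the start of a stream the attention cache holds left_context dummy frames, and only the most recent processed_lens of them carry real history; the lines above mask out the rest so the initial cache cannot be attended to. A standalone sketch of the mask construction, with make_pad_mask inlined and sizes chosen for illustration:

    import torch

    batch, left_context, chunk = 2, 8, 4
    processed_lens = torch.tensor([3, 8])  # frames already in each cache
    lengths = torch.tensor([4, 4])         # valid frames in the current chunk

    # Equivalent of make_pad_mask(lengths): True marks padding.
    src_key_padding_mask = (
        torch.arange(chunk).expand(batch, chunk) >= lengths.unsqueeze(1)
    )

    # Cache slots that were never filled (the "initial cache") get masked;
    # real history sits at the right end of the cache, hence the flip.
    processed_mask = torch.arange(left_context).expand(batch, left_context)
    processed_mask = (processed_lens.view(batch, 1) <= processed_mask).flip(1)

    # Cache frames precede the chunk in the attention keys, so prepend.
    mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
    print(mask[0].int())  # first 5 cache slots masked: only 3 of 8 are real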
@@ -282,8 +297,11 @@
warmup=warmup,
states=states,
left_context=left_context,
right_context=right_context,
) # (T, B, F)
if right_context > 0:
x = x[0:-right_context, ...]
lengths -= right_context
else:
src_key_padding_mask = make_pad_mask(lengths)
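
The trailing right_context encoder outputs in the streaming branch above were produced without their own future context, so they are dropped and recomputed by the next chunk. A toy illustration:

    import torch

    right_context = 4
    x = torch.randn(20, 2, 8)        # (T, B, F) output, lookahead included
    lengths = torch.tensor([20, 18])

    if right_context > 0:
        x = x[:-right_context]       # drop the provisional lookahead frames
        lengths = lengths - right_context
    print(x.shape, lengths)          # torch.Size([16, 2, 8]) tensor([16, 14])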
@@ -465,6 +483,7 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tensor:
"""
Pass the input through the encoder layer.
@@ -504,7 +523,12 @@
key = torch.cat([states[0], src], dim=0)
val = key
states[0] = key[-left_context:, ...]
if right_context > 0:
states[0] = key[
-(left_context + right_context) : -right_context, ... # noqa
]
else:
states[0] = key[-left_context:, ...]
# multi-headed self-attention module
src_att = self.self_attn(
@@ -520,7 +544,7 @@
src = src + self.dropout(src_att)
# convolution module
conv, conv_cache = self.conv_module(src, states[1])
conv, conv_cache = self.conv_module(src, states[1], right_context)
states[1] = conv_cache
src = src + self.dropout(conv)
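
The lookahead frames must also stay out of the attention cache, since the next chunk recomputes them: the slice above keeps the left_context frames that immediately precede the lookahead. A standalone sketch of that slice, with illustrative shapes:

    import torch

    left_context, right_context = 8, 4
    key = torch.arange(30.).view(30, 1, 1)  # (T, B, F); T includes lookahead

    if right_context > 0:
        # Skip the lookahead, keep the left_context frames before it.
        cache = key[-(left_context + right_context):-right_context]
    else:
        cache = key[-left_context:]
    print(cache[:, 0, 0])  # frames 18..25; lookahead frames 26..29 excluded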
@@ -604,6 +628,7 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@@ -655,6 +680,7 @@
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
left_context=left_context,
right_context=right_context,
)
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
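
Each layer owns one slice of the two stacked caches: the loop hands a layer its slice and writes the returned cache back, as above. A toy version of that bookkeeping, with shapes following the assertion earlier in this file (all values illustrative):

    import torch

    num_layers, left_context, batch, d_model = 2, 8, 2, 16
    cnn_module_kernel = 31

    states = [
        torch.zeros(num_layers, left_context, batch, d_model),           # attn
        torch.zeros(num_layers, cnn_module_kernel - 1, batch, d_model),  # conv
    ]

    for layer_index in range(num_layers):
        cache = [states[0][layer_index], states[1][layer_index]]
        # A real layer would return its updated caches; fake an update here.
        cache = [c + 1.0 for c in cache]
        states[0][layer_index] = cache[0]
        states[1][layer_index] = cache[1]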
@@ -1306,6 +1332,7 @@ class ConvolutionModule(nn.Module):
self,
x: Tensor,
cache: Optional[Tensor] = None,
right_context=0,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
@@ -1342,7 +1369,15 @@
), "Cache should be None in training time"
assert cache.size(0) == self.lorder
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
if right_context > 0:
cache = x.permute(2, 0, 1)[
-(self.lorder + right_context) : ( # noqa
-right_context
),
...,
]
else:
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)
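
The convolution cache receives the same treatment, except it keeps lorder frames and the slicing happens after permuting from the module's (B, C, T) layout back to (T, B, C). A small sketch:

    import torch

    lorder, right_context = 6, 4
    x = torch.arange(30.).view(1, 1, 30)  # (B, C, T) after the causal concat

    if right_context > 0:
        cache = x.permute(2, 0, 1)[-(lorder + right_context):-right_context]
    else:
        cache = x.permute(2, 0, 1)[-lorder:]
    print(cache[:, 0, 0])  # frames 20..25; the 4 lookahead frames excluded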

File 3 of 3 (streaming decoding script)

@@ -205,6 +205,13 @@ def get_parser():
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--right-context",
type=int,
default=4,
help="right context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
@@ -355,8 +362,8 @@ def decode_one_chunk(
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_feature_lens.append(stream.feature_len)
if params.decoding_method == "fast_beam_search":
processed_feature_lens.append(stream.feature_len)
rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
@@ -364,15 +371,18 @@
# If T is less than 7 there will be an error in the time reduction
# layer, because we subsample features with ((x_len - 1) // 2 - 1) // 2;
# the lookahead needs right_context * subsampling_factor extra frames.
if features.size(1) < 7:
feature_lens += 7 - features.size(1)
tail_length = 7 + params.right_context * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(
[
features,
torch.tensor(
LOG_EPS, dtype=features.dtype, device=device
).expand(
features.size(0), 7 - features.size(1), features.size(2)
features.size(0),
tail_length - features.size(1),
features.size(2),
),
],
dim=1,
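
The minimum tail is now 7 frames for the subsampling stack plus right_context * subsampling_factor frames of lookahead; utterances shorter than that are padded with a log-space filler, as above. A standalone sketch (LOG_EPS and all sizes are illustrative stand-ins):

    import torch

    LOG_EPS = -23.0                     # illustrative log-space padding value
    right_context, subsampling_factor = 4, 4
    features = torch.randn(2, 10, 80)   # (B, T, F); tail too short
    feature_lens = torch.tensor([10, 10])

    tail_length = 7 + right_context * subsampling_factor  # 23
    if features.size(1) < tail_length:
        pad = tail_length - features.size(1)
        feature_lens = feature_lens + pad
        filler = torch.full(
            (features.size(0), pad, features.size(2)), LOG_EPS
        )
        features = torch.cat([features, filler], dim=1)
    print(features.shape)  # torch.Size([2, 23, 80])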
@@ -382,6 +392,7 @@
torch.stack([x[0] for x in states], dim=2),
torch.stack([x[1] for x in states], dim=2),
]
processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
# Note: states will be modified in streaming_forward.
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
@@ -389,7 +400,10 @@
x_lens=feature_lens,
states=states,
left_context=params.left_context,
right_context=params.right_context,
processed_lens=processed_feature_lens,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
@@ -402,9 +416,6 @@
max_contexts=params.max_contexts,
max_states=params.max_states,
)
processed_feature_lens = torch.tensor(
processed_feature_lens, device=device
)
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_feature_lens + encoder_out_lens
hyp_tokens = fast_beam_search(
@@ -418,8 +429,8 @@
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].feature_len += encoder_out_lens[i]
if params.decoding_method == "fast_beam_search":
decode_streams[i].feature_len += encoder_out_lens[i]
decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
@@ -489,7 +500,8 @@
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
decode_stream.set_features(fbank(samples.to(device)))
feature = fbank(samples.to(device))
decode_stream.set_features(feature)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
@@ -541,14 +553,14 @@
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
store_transcripts(filename=recog_path, texts=results)
store_transcripts(filename=recog_path, texts=sorted(results))
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
@@ -602,6 +614,7 @@
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
params.suffix += f"-right-context-{params.right_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":