mask the initial cache

pkufool 2022-06-01 21:50:02 +08:00
parent 0325e3a04e
commit 9629be124d
3 changed files with 72 additions and 18 deletions

View File

@@ -39,6 +39,8 @@ class DecodeStream(object):
         if decoding_graph is not None:
             assert device == decoding_graph.device
 
+        self.params = params
+
         # It contains a 2-D tensors representing the feature frames.
         self.features: torch.Tensor = None
         # how many frames are processed. (before subsampling).
@@ -49,11 +51,12 @@ class DecodeStream(object):
         # The decoding result (partial or final) of current utterance.
         self.hyp: List = []
 
+        self.feature_len: int = 0
+
         if params.decoding_method == "greedy_search":
             self.hyp = [params.blank_id] * params.context_size
         elif params.decoding_method == "fast_beam_search":
             # feature_len is needed to get partial results.
-            self.feature_len: int = 0
             # The rnnt_decoding_stream for fast_beam_search.
             self.rnnt_decoding_stream: k2.RnntDecodingStream = (
                 k2.RnntDecodingStream(decoding_graph)
@@ -110,7 +113,10 @@ class DecodeStream(object):
             + ret_chunk_size,
             :,
         ]
 
-        self.num_processed_frames += chunk_size
+        self.num_processed_frames += (
+            chunk_size
+            - self.params.right_context * self.params.subsampling_factor
+        )
         if self.num_processed_frames >= self.features.size(0):
             self._done = True
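Note: a quick numeric check of the new bookkeeping (toy values below; the chunk size and subsampling factor are assumptions for illustration, not taken from the config). Each chunk now carries some look-ahead frames, so the stream pointer advances by less than the raw chunk size, and those look-ahead frames are re-fed with the next chunk.

    # Toy values, for illustration only. Frames are counted before
    # subsampling; right_context is counted after subsampling.
    chunk_size = 64           # feature frames fed to the encoder per chunk
    right_context = 4         # look-ahead frames (after subsampling)
    subsampling_factor = 4    # front-end subsampling rate

    advance = chunk_size - right_context * subsampling_factor
    print(advance)  # 48: the last 16 raw frames are look-ahead, re-processed next chunk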

View File

@@ -203,7 +203,9 @@ class Conformer(EncoderInterface):
         warmup: float = 1.0,
         chunk_size: int = 16,
         left_context: int = 64,
+        right_context: int = 4,
         simulate_streaming: bool = False,
+        processed_lens: Optional[Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -269,7 +271,20 @@ class Conformer(EncoderInterface):
                 {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
                 given {states[1].shape}."""
 
-            src_key_padding_mask = make_pad_mask(lengths + left_context)
+            # src_key_padding_mask = make_pad_mask(lengths + left_context)
+            src_key_padding_mask = make_pad_mask(lengths)
+
+            assert processed_lens is not None
+
+            processed_mask = torch.arange(left_context, device=x.device).expand(
+                x.size(0), left_context
+            )
+            processed_lens = processed_lens.view(x.size(0), 1)
+            processed_mask = (processed_lens <= processed_mask).flip(1)
+            src_key_padding_mask = torch.cat(
+                [processed_mask, src_key_padding_mask], dim=1
+            )
 
             embed = self.encoder_embed(x)
             embed, pos_enc = self.encoder_pos(embed, left_context)
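This hunk is the change the commit title refers to. The attention cache always holds left_context frames, but at the start of a stream only processed_lens of them contain real history; the rest are zero-initialized slots that must be masked out. A standalone sketch of that logic (names mirror the diff, but the batch size and lengths below are toy values; True means "padding", following the make_pad_mask convention):

    import torch

    left_context = 8
    batch = 2
    processed_lens = torch.tensor([3, 8])  # frames actually seen so far per stream

    positions = torch.arange(left_context).expand(batch, left_context)
    # A cache slot is padding if its (right-aligned) position has not been
    # filled yet, hence the comparison followed by a flip along time.
    processed_mask = (processed_lens.view(batch, 1) <= positions).flip(1)
    print(processed_mask[0])  # first 5 slots masked; last 3 hold real history
    print(processed_mask[1])  # nothing masked once the cache is full

The flip right-aligns the valid region because the cache stores the most recent frames at its end; the mask is then concatenated in front of the regular padding mask so attention covers cache plus current chunk.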
@@ -282,8 +297,11 @@ class Conformer(EncoderInterface):
                 warmup=warmup,
                 states=states,
                 left_context=left_context,
+                right_context=right_context,
             )  # (T, B, F)
+            if right_context > 0:
+                x = x[0:-right_context, ...]
+                lengths -= right_context
 
         else:
             src_key_padding_mask = make_pad_mask(lengths)
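The encoder emits chunk plus right-context frames; the trailing right_context frames are dropped because they were computed without their own full look-ahead and will be recomputed with the next chunk. A minimal illustration (shapes are assumptions, not the repo's):

    import torch

    right_context = 4
    x = torch.randn(20, 2, 512)       # (T, B, F) encoder output for one chunk
    lengths = torch.tensor([20, 18])

    if right_context > 0:
        x = x[0:-right_context, ...]  # keep only the fully-contexted frames
        lengths = lengths - right_context
    print(x.shape)  # torch.Size([16, 2, 512])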
@@ -465,6 +483,7 @@ class ConformerEncoderLayer(nn.Module):
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
         left_context: int = 0,
+        right_context: int = 0,
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
@@ -504,7 +523,12 @@ class ConformerEncoderLayer(nn.Module):
             key = torch.cat([states[0], src], dim=0)
             val = key
 
-            states[0] = key[-left_context:, ...]
+            if right_context > 0:
+                states[0] = key[
+                    -(left_context + right_context) : -right_context, ...  # noqa
+                ]
+            else:
+                states[0] = key[-left_context:, ...]
 
         # multi-headed self-attention module
         src_att = self.self_attn(
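The key sequence here is [cached_left_context, current_chunk]; the updated cache must hold the last left_context frames excluding the right-context tail, since those look-ahead frames do not belong to the consumed part of the stream. A toy check of the slicing (values assumed):

    import torch

    left_context, right_context = 4, 2
    key = torch.arange(10).float().view(10, 1, 1)  # (T, B, F) toy key sequence

    if right_context > 0:
        cache = key[-(left_context + right_context):-right_context, ...]
    else:
        cache = key[-left_context:, ...]
    print(cache.squeeze())  # tensor([4., 5., 6., 7.]) -- the tail 8, 9 is excluded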
@@ -520,7 +544,7 @@ class ConformerEncoderLayer(nn.Module):
         src = src + self.dropout(src_att)
 
         # convolution module
-        conv, conv_cache = self.conv_module(src, states[1])
+        conv, conv_cache = self.conv_module(src, states[1], right_context)
         states[1] = conv_cache
         src = src + self.dropout(conv)
@@ -604,6 +628,7 @@ class ConformerEncoder(nn.Module):
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
         left_context: int = 0,
+        right_context: int = 0,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
@@ -655,6 +680,7 @@ class ConformerEncoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
                 left_context=left_context,
+                right_context=right_context,
             )
             states[0][layer_index] = cache[0]
             states[1][layer_index] = cache[1]
@@ -1306,6 +1332,7 @@ class ConvolutionModule(nn.Module):
         self,
         x: Tensor,
         cache: Optional[Tensor] = None,
+        right_context=0,
     ) -> Tuple[Tensor, Tensor]:
         """Compute convolution module.
@@ -1342,7 +1369,15 @@ class ConvolutionModule(nn.Module):
             ), "Cache should be None in training time"
             assert cache.size(0) == self.lorder
             x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
-            cache = x.permute(2, 0, 1)[-self.lorder :, ...]  # noqa
+            if right_context > 0:
+                cache = x.permute(2, 0, 1)[
+                    -(self.lorder + right_context) : (  # noqa
+                        -right_context
+                    ),
+                    ...,
+                ]
+            else:
+                cache = x.permute(2, 0, 1)[-self.lorder :, ...]  # noqa
 
         x = self.depthwise_conv(x)
         x = self.deriv_balancer2(x)
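Same idea for the causal-convolution cache: keep the last lorder frames before the right-context tail, so the next chunk's left padding never contains unconsumed look-ahead frames. A sketch with assumed sizes:

    import torch

    lorder, right_context = 3, 2
    x = torch.arange(12).float().view(1, 1, 12)  # (B, C, T) after the cache concat

    x_t_first = x.permute(2, 0, 1)               # (T, B, C)
    if right_context > 0:
        cache = x_t_first[-(lorder + right_context):-right_context, ...]
    else:
        cache = x_t_first[-lorder:, ...]
    print(cache.squeeze())  # tensor([7., 8., 9.])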

View File

@@ -205,6 +205,13 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )
 
+    parser.add_argument(
+        "--right-context",
+        type=int,
+        default=4,
+        help="right context can be seen during decoding (in frames after subsampling)",
+    )
+
     parser.add_argument(
         "--num-decode-streams",
         type=int,
@@ -355,8 +362,8 @@ def decode_one_chunk(
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
+        processed_feature_lens.append(stream.feature_len)
         if params.decoding_method == "fast_beam_search":
-            processed_feature_lens.append(stream.feature_len)
             rnnt_stream_list.append(stream.rnnt_decoding_stream)
 
     feature_lens = torch.tensor(feature_lens, device=device)
@@ -364,15 +371,18 @@ def decode_one_chunk(
     # if T is less than 7 there will be an error in time reduction layer,
     # because we subsample features with ((x_len - 1) // 2 - 1) // 2
-    if features.size(1) < 7:
-        feature_lens += 7 - features.size(1)
+    tail_length = 7 + params.right_context * params.subsampling_factor
+    if features.size(1) < tail_length:
+        feature_lens += tail_length - features.size(1)
         features = torch.cat(
             [
                 features,
                 torch.tensor(
                     LOG_EPS, dtype=features.dtype, device=device
                 ).expand(
-                    features.size(0), 7 - features.size(1), features.size(2)
+                    features.size(0),
+                    tail_length - features.size(1),
+                    features.size(2),
                ),
             ],
             dim=1,
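A quick check of the tail-length arithmetic above (the parameter values are the defaults assumed in this commit): the front-end needs at least 7 frames for its ((x_len - 1) // 2 - 1) // 2 subsampling, plus enough extra frames to cover the right context before subsampling.

    right_context, subsampling_factor = 4, 4
    tail_length = 7 + right_context * subsampling_factor
    print(tail_length)  # 23: final chunks shorter than this are padded with LOG_EPS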
@@ -382,6 +392,7 @@ def decode_one_chunk(
         torch.stack([x[0] for x in states], dim=2),
         torch.stack([x[1] for x in states], dim=2),
     ]
+    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
 
     # Note: states will be modified in streaming_forward.
     encoder_out, encoder_out_lens = model.encoder.streaming_forward(
@@ -389,7 +400,10 @@ def decode_one_chunk(
         x_lens=feature_lens,
         states=states,
         left_context=params.left_context,
+        right_context=params.right_context,
+        processed_lens=processed_feature_lens,
     )
+
     encoder_out = model.joiner.encoder_proj(encoder_out)
 
     if params.decoding_method == "greedy_search":
@@ -402,9 +416,6 @@ def decode_one_chunk(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        processed_feature_lens = torch.tensor(
-            processed_feature_lens, device=device
-        )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
         processed_lens = processed_feature_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
@@ -418,8 +429,8 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
+        decode_streams[i].feature_len += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
-            decode_streams[i].feature_len += encoder_out_lens[i]
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
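Moving feature_len out of the fast_beam_search branch matters because every decoding method now needs it: the running total of emitted encoder frames is what the next chunk passes as processed_lens to mask the unfilled part of the attention cache. A toy illustration of the accumulation (per-chunk lengths are made up):

    feature_len = 0
    for encoder_out_len in [16, 16, 9]:  # encoder frames emitted per chunk
        feature_len += encoder_out_len   # fed back as processed_lens next chunk
    print(feature_len)  # 41 subsampled frames processed so far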
@@ -489,7 +500,8 @@ def decode_dataset(
         samples = torch.from_numpy(audio).squeeze(0)
         fbank = Fbank(opts)
-        decode_stream.set_features(fbank(samples.to(device)))
+        feature = fbank(samples.to(device))
+        decode_stream.set_features(feature)
         decode_stream.ground_truth = cut.supervisions[0].text
 
         decode_streams.append(decode_stream)
@@ -541,14 +553,14 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
         recog_path = (
             params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
         )
-        store_transcripts(filename=recog_path, texts=results)
+        store_transcripts(filename=recog_path, texts=sorted(results))
         logging.info(f"The transcripts are stored in {recog_path}")
 
     # The following prints out WERs, per-word error statistics and aligned
@@ -602,6 +614,7 @@ def main():
     # for streaming
     params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
     params.suffix += f"-left-context-{params.left_context}"
+    params.suffix += f"-right-context-{params.right_context}"
 
     # for fast_beam_search
     if params.decoding_method == "fast_beam_search":