Minor fixes

This commit is contained in:
pkufool 2022-06-07 12:00:26 +08:00
parent 1c794e32b0
commit d7be9bd9c5
21 changed files with 246 additions and 409 deletions

View File

@ -73,7 +73,7 @@ Usage:
--avg 15 \ --avg 15 \
--simulate-streaming 1 \ --simulate-streaming 1 \
--causal-convolution 1 \ --causal-convolution 1 \
--right-chunk-size 16 \ --decode-chunk-size 16 \
--left-context 64 \ --left-context 64 \
--exp-dir ./pruned_transducer_stateless/exp \ --exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \ --max-duration 600 \
@ -302,16 +302,9 @@ def get_parser():
test a streaming model. test a streaming model.
""", """,
) )
parser.add_argument( parser.add_argument(
"--causal-convolution", "--decode-chunk-size",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
type=int, type=int,
default=16, default=16,
help="The chunk size for decoding (in frames after subsampling)", help="The chunk size for decoding (in frames after subsampling)",
@ -379,7 +372,7 @@ def decode_one_batch(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
states=[], states=[],
chunk_size=params.right_chunk_size, chunk_size=params.decode_chunk_size,
left_context=params.left_context, left_context=params.left_context,
simulate_streaming=True, simulate_streaming=True,
) )
@ -610,7 +603,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming: if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}" params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
@ -646,9 +639,8 @@ def main():
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming: if params.simulate_streaming:
assert ( # Decoding in streaming requires causal convolution
params.causal_convolution params.causal_convolution = True
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)

View File

@ -37,6 +37,7 @@ class DecodeStream(object):
`get_init_state` in conformer.py `get_init_state` in conformer.py
decoding_graph: decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG. Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Used only when decoding_method is fast_beam_search.
device: device:
The device to run this stream. The device to run this stream.
""" """
@ -49,7 +50,8 @@ class DecodeStream(object):
# It contains a 2-D tensors representing the feature frames. # It contains a 2-D tensors representing the feature frames.
self.features: torch.Tensor = None self.features: torch.Tensor = None
# how many frames are processed. (before subsampling). # how many frames have been processed (before subsampling).
# we only modify this value in `func:get_feature_frames`.
self.num_processed_frames: int = 0 self.num_processed_frames: int = 0
self._done: bool = False self._done: bool = False
# The transcript of current utterance. # The transcript of current utterance.
@ -57,6 +59,9 @@ class DecodeStream(object):
# The decoding result (partial or final) of current utterance. # The decoding result (partial or final) of current utterance.
self.hyp: List = [] self.hyp: List = []
# how many frames have been processed, after subsampling (i.e. a
# cumulative sum of the second return value of
# encoder.streaming_forward).
self.feature_len: int = 0 self.feature_len: int = 0
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
@ -69,7 +74,7 @@ class DecodeStream(object):
else: else:
assert ( assert (
False False
), f"Decoding method :{params.decoding_method} do not support" ), f"Decoding method :{params.decoding_method} do not support."
@property @property
def done(self) -> bool: def done(self) -> bool:

View File

@ -149,14 +149,6 @@ def get_parser():
are streaming model, this should be True. are streaming model, this should be True.
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
exporting a streaming model.
""",
)
return parser return parser
@ -183,7 +175,7 @@ def main():
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.streaming_model: if params.streaming_model:
assert params.causal_convolution params.causal_convolution = True
logging.info(params) logging.info(params)

View File

@ -66,8 +66,6 @@ class Transducer(nn.Module):
prune_range: int = 5, prune_range: int = 5,
am_scale: float = 0.0, am_scale: float = 0.0,
lm_scale: float = 0.0, lm_scale: float = 0.0,
delay_penalty: float = 0.0,
return_sym_delay: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
@ -138,31 +136,10 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale, lm_only_scale=lm_scale,
am_only_scale=am_scale, am_only_scale=am_scale,
boundary=boundary, boundary=boundary,
delay_penalty=delay_penalty,
reduction="sum", reduction="sum",
return_grad=True, return_grad=True,
) )
sym_delay = None
if return_sym_delay:
B, S, T0 = px_grad.shape
T = T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2,
dtype=px_grad.dtype,
device=px_grad.device,
).expand(B, 1, 1)
total_syms = S * B
else:
offset = (boundary[:, 3] - 1) / 2
total_syms = torch.sum(boundary[:, 2])
offset = torch.arange(T0, device=px_grad.device).reshape(
1, 1, T0
) - offset.reshape(B, 1, 1)
sym_delay = px_grad * offset
sym_delay = torch.sum(sym_delay) / total_syms
# ranges : [B, T, prune_range] # ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges( ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad, px_grad=px_grad,
@ -186,8 +163,7 @@ class Transducer(nn.Module):
ranges=ranges, ranges=ranges,
termination_symbol=blank_id, termination_symbol=blank_id,
boundary=boundary, boundary=boundary,
delay_penalty=delay_penalty,
reduction="sum", reduction="sum",
) )
return (simple_loss, pruned_loss, sym_delay) return (simple_loss, pruned_loss)

View File

@ -20,9 +20,12 @@ Usage:
./pruned_transducer_stateless2/streaming_decode.py \ ./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--decode-chunk-size 8 \
--left-context 32 \
--right-context 2 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \ --decoding_method greedy_search \
--num-decode-streams 200 --num-decode-streams 1000
""" """
import argparse import argparse
@ -182,15 +185,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument( parser.add_argument(
"--decode-chunk-size", "--decode-chunk-size",
type=int, type=int,
@ -205,6 +199,13 @@ def get_parser():
help="left context can be seen during decoding (in frames after subsampling)", help="left context can be seen during decoding (in frames after subsampling)",
) )
parser.add_argument(
"--right-context",
type=int,
default=4,
help="right context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument( parser.add_argument(
"--num-decode-streams", "--num-decode-streams",
type=int, type=int,
@ -343,14 +344,18 @@ def decode_one_chunk(
processed_feature_lens = [] processed_feature_lens = []
for stream in decode_streams: for stream in decode_streams:
# we add 2 here because we will cut off one frame on each side of
# encoder_embed output as they see invalid padding, so we need 2 extra
# frames.
feat, feat_len = stream.get_feature_frames( feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor (params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor
) )
features.append(feat) features.append(feat)
feature_lens.append(feat_len) feature_lens.append(feat_len)
states.append(stream.states) states.append(stream.states)
processed_feature_lens.append(stream.feature_len)
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
processed_feature_lens.append(stream.feature_len)
rnnt_stream_list.append(stream.rnnt_decoding_stream) rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device) feature_lens = torch.tensor(feature_lens, device=device)
@ -358,15 +363,21 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer, # if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2 # because we subsample features with ((x_len - 1) // 2 - 1) // 2
if features.size(1) < 7: # we plus 2 here because we will cut off one frame on each size of
feature_lens += 7 - features.size(1) # encoder_embed output as they see invalid paddings. so we need extra 2
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat( features = torch.cat(
[ [
features, features,
torch.tensor( torch.tensor(
LOG_EPS, dtype=features.dtype, device=device LOG_EPS, dtype=features.dtype, device=device
).expand( ).expand(
features.size(0), 7 - features.size(1), features.size(2) features.size(0),
tail_length - features.size(1),
features.size(2),
), ),
], ],
dim=1, dim=1,
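A note on the padding arithmetic above: the time-reduction layer needs at least 7 input frames because it subsamples with ((x_len - 1) // 2 - 1) // 2, and each chunk additionally needs (2 + right_context) * subsampling_factor input frames to cover the one frame trimmed from each side of the encoder_embed output plus the right context. A small standalone check of that arithmetic (plain Python, not code from the repo):

    def subsampled_len(x_len: int) -> int:
        # Length after the subsampling layer, as quoted in the comment above.
        return ((x_len - 1) // 2 - 1) // 2

    def tail_length(right_context: int, subsampling_factor: int = 4) -> int:
        # Minimum input frames per chunk: 7 so the subsampling formula stays
        # positive, plus frames for the 2 trimmed embed frames and the right context.
        return 7 + (2 + right_context) * subsampling_factor

    for rc in (0, 2, 4):
        t = tail_length(rc)
        print(f"right_context={rc}: pad to {t} frames -> {subsampled_len(t)} after subsampling")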
@ -377,12 +388,16 @@ def decode_one_chunk(
torch.stack([x[1] for x in states], dim=2), torch.stack([x[1] for x in states], dim=2),
] ]
processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
# Note: states will be modified in streaming_forward. # Note: states will be modified in streaming_forward.
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward( encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features, x=features,
x_lens=feature_lens, x_lens=feature_lens,
states=states, states=states,
left_context=params.left_context, left_context=params.left_context,
right_context=params.right_context,
processed_lens=processed_feature_lens,
) )
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
@ -395,9 +410,6 @@ def decode_one_chunk(
max_contexts=params.max_contexts, max_contexts=params.max_contexts,
max_states=params.max_states, max_states=params.max_states,
) )
processed_feature_lens = torch.tensor(
processed_feature_lens, device=device
)
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config) decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_feature_lens + encoder_out_lens processed_lens = processed_feature_lens + encoder_out_lens
hyp_tokens = fast_beam_search( hyp_tokens = fast_beam_search(
@ -411,8 +423,8 @@ def decode_one_chunk(
finished_streams = [] finished_streams = []
for i in range(len(decode_streams)): for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]] decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].feature_len += encoder_out_lens[i]
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
decode_streams[i].feature_len += encoder_out_lens[i]
decode_streams[i].hyp = hyp_tokens[i] decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done: if decode_streams[i].done:
finished_streams.append(i) finished_streams.append(i)
@ -457,12 +469,14 @@ def decode_dataset(
opts.frame_opts.samp_freq = 16000 opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80 opts.mel_opts.num_bins = 80
log_interval = 300 log_interval = 50
decode_results = [] decode_results = []
# Contain decode streams currently running. # Contain decode streams currently running.
decode_streams = [] decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device) initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts): for num, cut in enumerate(cuts):
# each utterance has a DecodeStream. # each utterance has a DecodeStream.
decode_stream = DecodeStream( decode_stream = DecodeStream(
@ -536,7 +550,7 @@ def decode_dataset(
def save_results( def save_results(
params: AttributeDict, params: AttributeDict,
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]], results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
): ):
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
@ -597,6 +611,7 @@ def main():
# for streaming # for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}" params.suffix += f"-left-context-{params.left_context}"
params.suffix += f"-right-context-{params.right_context}"
# for fast_beam_search # for fast_beam_search
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -620,10 +635,7 @@ def main():
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
params.causal_convolution = True
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)

View File

@ -37,7 +37,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir pruned_transducer_stateless/exp \ --exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \ --full-libri 1 \
--dynamic-chunk-training 1 \ --dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \ --short-chunk-size 25 \
--num-left-chunks 4 \ --num-left-chunks 4 \
--max-duration 300 --max-duration 300
@ -244,15 +243,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument( parser.add_argument(
"--short-chunk-size", "--short-chunk-size",
type=int, type=int,
@ -269,25 +259,6 @@ def get_parser():
help="How many left context can be seen in chunks when calculating attention.", help="How many left context can be seen in chunks when calculating attention.",
) )
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser return parser
@ -554,17 +525,14 @@ def compute_loss(
y = sp.encode(texts, out_type=int) y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
sym_delay = None
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, sym_delay = model( simple_loss, pruned_loss = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
prune_range=params.prune_range, prune_range=params.prune_range,
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
delay_penalty=params.delay_penalty,
return_sym_delay=params.return_sym_delay,
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss loss = params.simple_loss_scale * simple_loss + pruned_loss
@ -582,9 +550,6 @@ def compute_loss(
info["simple_loss"] = simple_loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item()
if sym_delay is not None:
info["sym_delay"] = sym_delay.detatch().cpu().item()
return loss, info return loss, info
@ -839,9 +804,8 @@ def run(rank, world_size, args):
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training: if params.dynamic_chunk_training:
assert ( # dynamic_chunk_training requires causal convolution
params.causal_convolution params.causal_convolution = True
), "dynamic_chunk_training requires causal convolution"
logging.info(params) logging.info(params)

View File

@ -279,19 +279,26 @@ class Conformer(EncoderInterface):
The chunk size for decoding, this will be used to simulate streaming The chunk size for decoding, this will be used to simulate streaming
decoding using masking. decoding using masking.
left_context: left_context:
How many old frames the attention can see in current chunk, it MUST How many previous frames the attention can see in current chunk.
be equal to left_context in decode_states. Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
simulate_streaming: simulate_streaming:
If setting True, it will use a masking strategy to simulate streaming If setting True, it will use a masking strategy to simulate streaming
fashion (i.e. every chunk data only see limited left context and fashion (i.e. every chunk data only see limited left context and
right context). The whole sequence is supposed to be send at a time right context). The whole sequence is supposed to be send at a time
When using simulate_streaming. When using simulate_streaming.
processed_lens:
How many frames (after subsampling) have been processed for each sequence.
Returns: Returns:
Return a tuple containing 2 tensors: Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim) - logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number - logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding. of frames in `logits` before padding.
- decode_states, the updated DecodeStates including the information - decode_states, the updated states including the information
of current chunk. of current chunk.
""" """
@ -321,8 +328,6 @@ class Conformer(EncoderInterface):
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}.""" given {states[1].shape}."""
# src_key_padding_mask = make_pad_mask(lengths + left_context)
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
@ -341,6 +346,8 @@ class Conformer(EncoderInterface):
embed = self.encoder_embed(x) embed = self.encoder_embed(x)
# cut off 1 frame on each side of embed as they see the padding
# value, which causes a training and decoding mismatch.
embed = embed[:, 1:-1, :] embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context) embed, pos_enc = self.encoder_pos(embed, left_context)
@ -359,7 +366,8 @@ class Conformer(EncoderInterface):
x = x[0:-right_context, ...] x = x[0:-right_context, ...]
lengths -= right_context lengths -= right_context
else: else:
# this branch simulates streaming decoding using the same masking
# as is used at training time.
src_key_padding_mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x) x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x) x, pos_emb = self.encoder_pos(x)
@ -558,9 +566,14 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently. bypass layers more frequently.
left_context: left context (in frames) used during streaming decoding. left_context:
this is used only in real streaming decoding, in other circumstances, How many previous frames the attention can see in current chunk.
it MUST be 0. Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
@ -708,10 +721,14 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently. bypass layers more frequently.
left_context: left context (in frames) used during streaming decoding. left_context:
this is used only in real streaming decoding, in other circumstances, How many previous frames the attention can see in current chunk.
it MUST be 0. Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E). pos_emb: (N, 2*(S+left_context)-1, E).
@ -1273,9 +1290,17 @@ class RelPositionMultiheadAttention(nn.Module):
and attn_mask.dtype == torch.bool and attn_mask.dtype == torch.bool
and key_padding_mask is not None and key_padding_mask is not None
): ):
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( if attn_mask.size(0) != 1:
1 attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
).unsqueeze(2) combined_mask = attn_mask | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
else:
# attn_mask.shape == (1, tgt_len, src_len)
combined_mask = attn_mask.unsqueeze(
0
) | key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_output_weights = attn_output_weights.view( attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len bsz, num_heads, tgt_len, src_len
) )
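The new branch above exists because a boolean attn_mask may arrive either shared across the batch with shape (1, tgt_len, src_len) or already expanded per head with shape (bsz * num_heads, tgt_len, src_len), and in both cases it must broadcast against a key_padding_mask of shape (bsz, src_len). A standalone illustration of the two broadcasts (the sizes and names here are made up for the example):

    import torch

    bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
    key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
    key_padding_mask[1, -2:] = True  # last two keys of sequence 1 are padding

    # Case 1: one mask shared by the whole batch, shape (1, tgt_len, src_len).
    shared = torch.zeros(1, tgt_len, src_len, dtype=torch.bool)
    combined = shared.unsqueeze(0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
    print(combined.shape)  # torch.Size([2, 1, 5, 5]) -> broadcasts over heads

    # Case 2: a per-head mask, shape (bsz * num_heads, tgt_len, src_len).
    per_head = torch.zeros(bsz * num_heads, tgt_len, src_len, dtype=torch.bool)
    per_head = per_head.view(bsz, num_heads, tgt_len, src_len)
    combined = per_head | key_padding_mask.unsqueeze(1).unsqueeze(2)
    print(combined.shape)  # torch.Size([2, 4, 5, 5])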
@ -1404,6 +1429,10 @@ class ConvolutionModule(nn.Module):
x: Input tensor (#time, batch, channels). x: Input tensor (#time, batch, channels).
cache: The cache of depthwise_conv, only used in real streaming cache: The cache of depthwise_conv, only used in real streaming
decoding. decoding.
right_context:
How many frames at the end of the current chunk belong to the right
context; they are excluded when updating the convolution cache.
Returns: Returns:
If cache is None return the output tensor (#time, batch, channels). If cache is None return the output tensor (#time, batch, channels).

View File

@ -59,8 +59,7 @@ Usage:
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--simulate-streaming 1 \ --simulate-streaming 1 \
--causal-convolution 1 \ --decode-chunk-size 16 \
--right-chunk-size 16 \
--left-context 64 \ --left-context 64 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 600 \
@ -257,16 +256,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--causal-convolution", "--decode-chunk-size",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
type=int, type=int,
default=16, default=16,
help="The chunk size for decoding (in frames after subsampling)", help="The chunk size for decoding (in frames after subsampling)",
@ -335,11 +325,11 @@ def decode_one_batch(
) )
if params.simulate_streaming: if params.simulate_streaming:
encoder_out, encoder_out_lens = model.encoder.streaming_forward( encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
states=[], states=[],
chunk_size=params.right_chunk_size, chunk_size=params.decode_chunk_size,
left_context=params.left_context, left_context=params.left_context,
simulate_streaming=True, simulate_streaming=True,
) )
@ -561,7 +551,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming: if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}" params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
@ -594,9 +584,8 @@ def main():
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming: if params.simulate_streaming:
assert ( # Decoding in streaming requires causal convolution
params.causal_convolution params.causal_convolution = True
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)

View File

@ -164,14 +164,6 @@ def get_parser():
are streaming model, this should be True. are streaming model, this should be True.
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
exporting a streaming model.
""",
)
return parser return parser
@ -197,7 +189,7 @@ def main():
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.streaming_model: if params.streaming_model:
assert params.causal_convolution params.causal_convolution = True
logging.info(params) logging.info(params)

View File

@ -78,8 +78,6 @@ class Transducer(nn.Module):
am_scale: float = 0.0, am_scale: float = 0.0,
lm_scale: float = 0.0, lm_scale: float = 0.0,
warmup: float = 1.0, warmup: float = 1.0,
delay_penalty: float = 0.0,
return_sym_delay: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
@ -157,31 +155,10 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale, lm_only_scale=lm_scale,
am_only_scale=am_scale, am_only_scale=am_scale,
boundary=boundary, boundary=boundary,
delay_penalty=delay_penalty,
reduction="sum", reduction="sum",
return_grad=True, return_grad=True,
) )
sym_delay = None
if return_sym_delay:
B, S, T0 = px_grad.shape
T = T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2,
dtype=px_grad.dtype,
device=px_grad.device,
).expand(B, 1, 1)
total_syms = S * B
else:
offset = (boundary[:, 3] - 1) / 2
total_syms = torch.sum(boundary[:, 2])
offset = torch.arange(T0, device=px_grad.device).reshape(
1, 1, T0
) - offset.reshape(B, 1, 1)
sym_delay = px_grad * offset
sym_delay = torch.sum(sym_delay) / total_syms
# ranges : [B, T, prune_range] # ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges( ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad, px_grad=px_grad,
@ -210,9 +187,8 @@ class Transducer(nn.Module):
symbols=y_padded, symbols=y_padded,
ranges=ranges, ranges=ranges,
termination_symbol=blank_id, termination_symbol=blank_id,
delay_penalty=delay_penalty,
boundary=boundary, boundary=boundary,
reduction="sum", reduction="sum",
) )
return (simple_loss, pruned_loss, sym_delay) return (simple_loss, pruned_loss)

View File

@ -20,9 +20,12 @@ Usage:
./pruned_transducer_stateless2/streaming_decode.py \ ./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 2 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \ --decoding_method greedy_search \
--num-decode-streams 200 --num-decode-streams 1000
""" """
import argparse import argparse
@ -182,15 +185,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument( parser.add_argument(
"--decode-chunk-size", "--decode-chunk-size",
type=int, type=int,
@ -356,6 +350,9 @@ def decode_one_chunk(
processed_feature_lens = [] processed_feature_lens = []
for stream in decode_streams: for stream in decode_streams:
# we add 2 here because we will cut off one frame on each side of
# encoder_embed output as they see invalid padding, so we need 2 extra
# frames.
feat, feat_len = stream.get_feature_frames( feat, feat_len = stream.get_feature_frames(
(params.decode_chunk_size + 2 + params.right_context) (params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor * params.subsampling_factor
@ -372,7 +369,7 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer, # if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2 # because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 15 + params.right_context * params.subsampling_factor tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length: if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1) feature_lens += tail_length - features.size(1)
features = torch.cat( features = torch.cat(
@ -642,10 +639,8 @@ def main():
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
# Decoding in streaming requires causal convolution
assert ( params.causal_convolution = True
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)

View File

@ -48,11 +48,9 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir pruned_transducer_stateless/exp \ --exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \ --full-libri 1 \
--dynamic-chunk-training 1 \ --dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \ --short-chunk-size 25 \
--num-left-chunks 4 \ --num-left-chunks 4 \
--max-duration 300 --max-duration 300
""" """
@ -285,15 +283,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument( parser.add_argument(
"--short-chunk-size", "--short-chunk-size",
type=int, type=int,
@ -310,25 +299,6 @@ def get_parser():
help="How many left context can be seen in chunks when calculating attention.", help="How many left context can be seen in chunks when calculating attention.",
) )
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser return parser
@ -611,7 +581,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, sym_delay = model( simple_loss, pruned_loss = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -619,8 +589,6 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
warmup=warmup, warmup=warmup,
delay_penalty=params.delay_penalty,
return_sym_delay=params.return_sym_delay,
) )
# after the main warmup step, we keep pruned_loss_scale small # after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid # for the same amount of time (model_warm_step), to avoid
@ -650,9 +618,6 @@ def compute_loss(
info["simple_loss"] = simple_loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.return_sym_delay:
info["sym_delay"] = sym_delay.detach().cpu().item()
return loss, info return loss, info
@ -882,13 +847,8 @@ def run(rank, world_size, args):
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training: if params.dynamic_chunk_training:
assert ( # dynamic_chunk_training requires causal convolution
params.causal_convolution params.causal_convolution = True
), "dynamic_chunk_training requires causal convolution"
else:
assert (
params.delay_penalty == 0.0
), "delay_penalty is intended for dynamic_chunk_training"
logging.info(params) logging.info(params)

View File

@ -266,16 +266,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--causal-convolution", "--decode-chunk-size",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
type=int, type=int,
default=16, default=16,
help="The chunk size for decoding (in frames after subsampling)", help="The chunk size for decoding (in frames after subsampling)",
@ -348,7 +339,7 @@ def decode_one_batch(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
states=[], states=[],
chunk_size=params.right_chunk_size, chunk_size=params.decode_chunk_size,
left_context=params.left_context, left_context=params.left_context,
simulate_streaming=True, simulate_streaming=True,
) )
@ -596,7 +587,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming: if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}" params.suffix += f"-left-context-{params.left_context}"
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -635,9 +626,8 @@ def main():
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming: if params.simulate_streaming:
assert ( # Decoding in streaming requires causal convolution
params.causal_convolution params.causal_convolution = True
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)

View File

@ -165,14 +165,6 @@ def get_parser():
are streaming model, this should be True. are streaming model, this should be True.
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
exporting a streaming model.
""",
)
return parser return parser
@ -198,7 +190,7 @@ def main():
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.streaming_model: if params.streaming_model:
assert params.causal_convolution params.causal_convolution = True
logging.info(params) logging.info(params)

View File

@ -20,9 +20,12 @@ Usage:
./pruned_transducer_stateless2/streaming_decode.py \ ./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 2 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \ --decoding_method greedy_search \
--num-decode-streams 200 --num-decode-streams 1000
""" """
import argparse import argparse
@ -183,15 +186,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument( parser.add_argument(
"--decode-chunk-size", "--decode-chunk-size",
type=int, type=int,
@ -357,6 +351,9 @@ def decode_one_chunk(
processed_feature_lens = [] processed_feature_lens = []
for stream in decode_streams: for stream in decode_streams:
# we add 2 here because we will cut off one frame on each side of
# encoder_embed output as they see invalid padding, so we need 2 extra
# frames.
feat, feat_len = stream.get_feature_frames( feat, feat_len = stream.get_feature_frames(
(params.decode_chunk_size + 2 + params.right_context) (params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor * params.subsampling_factor
@ -373,7 +370,10 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer, # if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2 # because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 15 + params.right_context * params.subsampling_factor # we add 2 here because we will cut off one frame on each side of
# encoder_embed output as they see invalid padding, so we need 2 extra
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length: if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1) feature_lens += tail_length - features.size(1)
features = torch.cat( features = torch.cat(
@ -481,7 +481,9 @@ def decode_dataset(
decode_results = [] decode_results = []
# Contain decode streams currently running. # Contain decode streams currently running.
decode_streams = [] decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device) initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts): for num, cut in enumerate(cuts):
# each utterance has a DecodeStream. # each utterance has a DecodeStream.
decode_stream = DecodeStream( decode_stream = DecodeStream(
@ -641,10 +643,8 @@ def main():
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
# Decoding in streaming requires causal convolution
assert ( params.causal_convolution = True
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)

View File

@ -295,15 +295,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument( parser.add_argument(
"--short-chunk-size", "--short-chunk-size",
type=int, type=int,
@ -320,25 +311,6 @@ def get_parser():
help="How many left context can be seen in chunks when calculating attention.", help="How many left context can be seen in chunks when calculating attention.",
) )
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser return parser
@ -963,13 +935,8 @@ def run(rank, world_size, args):
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training: if params.dynamic_chunk_training:
assert ( # dynamic_chunk_training requires causal convolution
params.causal_convolution params.causal_convolution = True
), "dynamic_chunk_training requires causal convolution"
else:
assert (
params.delay_penalty == 0.0
), "delay_penalty is intended for dynamic_chunk_training"
logging.info(params) logging.info(params)

View File

@ -60,8 +60,7 @@ Usage:
--epoch 30 \ --epoch 30 \
--avg 15 \ --avg 15 \
--simulate-streaming 1 \ --simulate-streaming 1 \
--causal-convolution 1 \ --decode-chunk-size 16 \
--right-chunk-size 16 \
--left-context 64 \ --left-context 64 \
--exp-dir ./pruned_transducer_stateless4/exp \ --exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \ --max-duration 600 \
@ -269,16 +268,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--causal-convolution", "--decode-chunk-size",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
type=int, type=int,
default=16, default=16,
help="The chunk size for decoding (in frames after subsampling)", help="The chunk size for decoding (in frames after subsampling)",
@ -351,7 +341,7 @@ def decode_one_batch(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
states=[], states=[],
chunk_size=params.right_chunk_size, chunk_size=params.decode_chunk_size,
left_context=params.left_context, left_context=params.left_context,
simulate_streaming=True, simulate_streaming=True,
) )
@ -573,7 +563,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming: if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}" params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
@ -609,9 +599,8 @@ def main():
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming: if params.simulate_streaming:
assert ( # Decoding in streaming requires causal convolution
params.causal_convolution params.causal_convolution = True
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)

View File

@ -20,6 +20,9 @@ Usage:
./pruned_transducer_stateless2/streaming_decode.py \ ./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 2 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \ --decoding_method greedy_search \
--num-decode-streams 200 --num-decode-streams 200
@ -194,15 +197,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument( parser.add_argument(
"--decode-chunk-size", "--decode-chunk-size",
type=int, type=int,
@ -368,6 +362,9 @@ def decode_one_chunk(
processed_feature_lens = [] processed_feature_lens = []
for stream in decode_streams: for stream in decode_streams:
# we add 2 here because we will cut off one frame on each side of
# encoder_embed output as they see invalid padding, so we need 2 extra
# frames.
feat, feat_len = stream.get_feature_frames( feat, feat_len = stream.get_feature_frames(
(params.decode_chunk_size + 2 + params.right_context) (params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor * params.subsampling_factor
@ -384,7 +381,10 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer, # if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2 # because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 15 + params.right_context * params.subsampling_factor # we add 2 here because we will cut off one frame on each side of
# encoder_embed output as they see invalid padding, so we need 2 extra
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length: if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1) feature_lens += tail_length - features.size(1)
features = torch.cat( features = torch.cat(
@ -492,7 +492,9 @@ def decode_dataset(
decode_results = [] decode_results = []
# Contain decode streams currently running. # Contain decode streams currently running.
decode_streams = [] decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device) initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts): for num, cut in enumerate(cuts):
# each utterance has a DecodeStream. # each utterance has a DecodeStream.
decode_stream = DecodeStream( decode_stream = DecodeStream(
@ -655,10 +657,8 @@ def main():
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
# Decoding in streaming requires causal convolution
assert ( params.causal_convolution = True
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)

View File

@ -49,7 +49,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir pruned_transducer_stateless4/exp \ --exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \ --full-libri 1 \
--dynamic-chunk-training 1 \ --dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \ --short-chunk-size 25 \
--num-left-chunks 4 \ --num-left-chunks 4 \
--max-duration 300 --max-duration 300
@ -302,15 +301,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument( parser.add_argument(
"--short-chunk-size", "--short-chunk-size",
type=int, type=int,
@ -327,25 +317,6 @@ def get_parser():
help="How many left context can be seen in chunks when calculating attention.", help="How many left context can be seen in chunks when calculating attention.",
) )
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser return parser
@ -640,7 +611,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, sym_delay = model( simple_loss, pruned_loss = model(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
y=y, y=y,
@ -648,8 +619,6 @@ def compute_loss(
am_scale=params.am_scale, am_scale=params.am_scale,
lm_scale=params.lm_scale, lm_scale=params.lm_scale,
warmup=warmup, warmup=warmup,
delay_penalty=params.delay_penalty,
return_sym_delay=params.return_sym_delay,
) )
# after the main warmup step, we keep pruned_loss_scale small # after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid # for the same amount of time (model_warm_step), to avoid
@ -679,9 +648,6 @@ def compute_loss(
info["simple_loss"] = simple_loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.return_sym_delay:
info["sym_delay"] = sym_delay.detach().cpu().item()
return loss, info return loss, info
@ -922,13 +888,8 @@ def run(rank, world_size, args):
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training: if params.dynamic_chunk_training:
assert ( # dynamic_chunk_training requires causal convolution
params.causal_convolution params.causal_convolution = True
), "dynamic_chunk_training requires causal convolution"
else:
assert (
params.delay_penalty == 0.0
), "delay_penalty is intended for dynamic_chunk_training"
logging.info(params) logging.info(params)

View File

@ -248,7 +248,9 @@ class Conformer(Transformer):
states: List[torch.Tensor], states: List[torch.Tensor],
chunk_size: int = 16, chunk_size: int = 16,
left_context: int = 64, left_context: int = 64,
right_context: int = 0,
simulate_streaming: bool = False, simulate_streaming: bool = False,
processed_lens: Optional[Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
""" """
Args: Args:
@ -268,13 +270,20 @@ class Conformer(Transformer):
The chunk size for decoding, this will be used to simulate streaming The chunk size for decoding, this will be used to simulate streaming
decoding using masking. decoding using masking.
left_context: left_context:
How many old frames the attention can see in current chunk, it MUST How many previous frames the attention can see in current chunk.
be equal to left_context in decode_states. Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
simulate_streaming: simulate_streaming:
If setting True, it will use a masking strategy to simulate streaming If setting True, it will use a masking strategy to simulate streaming
fashion (i.e. every chunk data only see limited left context and fashion (i.e. every chunk data only see limited left context and
right context). The whole sequence is supposed to be send at a time right context). The whole sequence is supposed to be send at a time
When using simulate_streaming. When using simulate_streaming.
processed_lens:
How many frames (after subsampling) have been processed for each sequence.
Returns: Returns:
Return a tuple containing 2 tensors: Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim) - logits, its shape is (batch_size, output_seq_len, output_dim)
@ -310,9 +319,27 @@ class Conformer(Transformer):
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}.""" given {states[1].shape}."""
src_key_padding_mask = make_pad_mask(lengths + left_context) lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths)
assert processed_lens is not None
processed_mask = torch.arange(left_context, device=x.device).expand(
x.size(0), left_context
)
processed_lens = processed_lens.view(x.size(0), 1)
processed_mask = (processed_lens <= processed_mask).flip(1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
embed = self.encoder_embed(x) embed = self.encoder_embed(x)
# cut off 1 frame on each side of embed as they see the padding
# value, which causes a training and decoding mismatch.
embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context) embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
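The processed_mask built above marks which of the left_context cached positions hold real history: a stream that has processed fewer than left_context frames must not attend to the still-unfilled part of its cache. A self-contained replay of the same tensor operations (the batch size and processed_lens values are made up):

    import torch

    batch, left_context = 2, 8
    # frames (after subsampling) already processed by each stream (made-up values)
    processed_lens = torch.tensor([3, 8])

    positions = torch.arange(left_context).expand(batch, left_context)
    # True where the cached position is NOT yet filled with real history.
    processed_mask = (processed_lens.view(batch, 1) <= positions).flip(1)
    print(processed_mask.int())
    # stream 0 has only 3 processed frames, so its 5 oldest cache slots are masked:
    # tensor([[1, 1, 1, 1, 1, 0, 0, 0],
    #         [0, 0, 0, 0, 0, 0, 0, 0]])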
@ -322,6 +349,7 @@ class Conformer(Transformer):
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
states=states, states=states,
left_context=left_context, left_context=left_context,
right_context=right_context,
) # (T, B, F) ) # (T, B, F)
else: else:
src_key_padding_mask = make_pad_mask(lengths) src_key_padding_mask = make_pad_mask(lengths)
@ -512,6 +540,7 @@ class ConformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None, src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0, left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]: ) -> Tuple[Tensor, List[Tensor]]:
""" """
Pass the input through the encoder layer. Pass the input through the encoder layer.
@ -528,9 +557,14 @@ class ConformerEncoderLayer(nn.Module):
Note: states will be modified in this function. Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional). src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
left_context: left context (in frames) used during streaming decoding. left_context:
this is used only in real streaming decoding, in other circumstances, How many previous frames the attention can see in current chunk.
it MUST be 0. Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E). pos_emb: (N, 2*(S+left_context)-1, E).
@ -562,7 +596,12 @@ class ConformerEncoderLayer(nn.Module):
# separately) if needed. # separately) if needed.
key = torch.cat([states[0], src], dim=0) key = torch.cat([states[0], src], dim=0)
val = key val = key
states[0] = key[-left_context:, ...] if right_context > 0:
states[0] = key[
-(left_context + right_context) : -right_context, ... # noqa
]
else:
states[0] = key[-left_context:, ...]
src_att = self.self_attn( src_att = self.self_attn(
src, src,
@ -582,7 +621,9 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before: if self.normalize_before:
src = self.norm_conv(src) src = self.norm_conv(src)
src, conv_cache = self.conv_module(src, states[1]) src, conv_cache = self.conv_module(
src, states[1], right_context=right_context
)
states[1] = conv_cache states[1] = conv_cache
src = residual + self.dropout(src) src = residual + self.dropout(src)
@ -669,6 +710,7 @@ class ConformerEncoder(nn.Module):
mask: Optional[Tensor] = None, mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0, left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]: ) -> Tuple[Tensor, List[Tensor]]:
r"""Pass the input through the encoder layers in turn. r"""Pass the input through the encoder layers in turn.
@ -684,9 +726,14 @@ class ConformerEncoder(nn.Module):
Note: states will be modified in this function. Note: states will be modified in this function.
mask: the mask for the src sequence (optional). mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
left_context: left context (in frames) used during streaming decoding. left_context:
this is used only in real streaming decoding, in other circumstances, How many previous frames the attention can see in current chunk.
it MUST be 0. Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E). pos_emb: (N, 2*(S+left_context)-1, E).
@ -707,6 +754,7 @@ class ConformerEncoder(nn.Module):
src_mask=mask, src_mask=mask,
src_key_padding_mask=src_key_padding_mask, src_key_padding_mask=src_key_padding_mask,
left_context=left_context, left_context=left_context,
right_context=right_context,
) )
states[0][layer_index] = cache[0] states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1] states[1][layer_index] = cache[1]
@ -1329,7 +1377,10 @@ class ConvolutionModule(nn.Module):
self.activation = Swish() self.activation = Swish()
def forward( def forward(
self, x: Tensor, cache: Optional[Tensor] = None self,
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
"""Compute convolution module. """Compute convolution module.
@ -1359,7 +1410,15 @@ class ConvolutionModule(nn.Module):
), "Cache should be None in training time" ), "Cache should be None in training time"
assert cache.size(0) == self.lorder assert cache.size(0) == self.lorder
x = torch.cat([cache.permute(1, 2, 0), x], dim=2) x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa if right_context > 0:
cache = x.permute(2, 0, 1)[
-(self.lorder + right_context) : ( # noqa
-right_context
),
...,
]
else:
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
x = self.depthwise_conv(x) x = self.depthwise_conv(x)
# x is (batch, channels, time) # x is (batch, channels, time)
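This convolution-cache update mirrors the attention-cache update earlier in the file: when right_context > 0, the last right_context frames of the chunk are look-ahead frames and are left out of the stored cache, which keeps the most recent cache_size frames before them. A tiny standalone check of the slicing (toy 1-D tensor, not repo code):

    import torch

    frames = torch.arange(12)          # pretend frame indices 0..11 along the time axis
    cache_size, right_context = 4, 2   # e.g. left_context, or ConvolutionModule.lorder

    if right_context > 0:
        cache = frames[-(cache_size + right_context):-right_context]
    else:
        cache = frames[-cache_size:]

    print(cache)  # tensor([6, 7, 8, 9]): the newest 4 frames, excluding the 2 look-ahead frames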

View File

@ -535,11 +535,8 @@ class MetricsTracker(collections.defaultdict):
ans = [] ans = []
for k, v in self.items(): for k, v in self.items():
if k != "frames": if k != "frames":
if k != "sym_delay": norm_value = float(v) / num_frames
norm_value = float(v) / num_frames ans.append((k, norm_value))
ans.append((k, norm_value))
else:
ans.append((k, float(v)))
return ans return ans
def reduce(self, device): def reduce(self, device):
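With the sym_delay special case removed, every tracked statistic other than the frame count is normalized by the number of frames. A small standalone sketch of that normalization (made-up numbers, not the icefall MetricsTracker class):

    import collections

    # made-up running sums, keyed like MetricsTracker entries
    metrics = collections.defaultdict(int, {"frames": 2000, "loss": 150.0, "simple_loss": 90.0})

    num_frames = metrics["frames"]
    normalized = [(k, float(v) / num_frames) for k, v in metrics.items() if k != "frames"]
    print(normalized)  # [('loss', 0.075), ('simple_loss', 0.045)]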