Minor fixes

This commit is contained in:
pkufool 2022-06-07 12:00:26 +08:00
parent 1c794e32b0
commit d7be9bd9c5
21 changed files with 246 additions and 409 deletions

View File

@ -73,7 +73,7 @@ Usage:
--avg 15 \
--simulate-streaming 1 \
--causal-convolution 1 \
--right-chunk-size 16 \
--decode-chunk-size 16 \
--left-context 64 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
@ -302,16 +302,9 @@ def get_parser():
test a streaming model.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
@ -379,7 +372,7 @@ def decode_one_batch(
x=feature,
x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
@ -610,7 +603,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
@ -646,9 +639,8 @@ def main():
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
# Decoding in streaming requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -37,6 +37,7 @@ class DecodeStream(object):
`get_init_state` in conformer.py
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
Used only when decoding_method is fast_beam_search.
device:
The device to run this stream.
"""
@ -49,7 +50,8 @@ class DecodeStream(object):
# It contains a 2-D tensors representing the feature frames.
self.features: torch.Tensor = None
# how many frames are processed. (before subsampling).
# how many frames have been processed (before subsampling).
# we only modify this value in `func:get_feature_frames`.
self.num_processed_frames: int = 0
self._done: bool = False
# The transcript of current utterance.
@ -57,6 +59,9 @@ class DecodeStream(object):
# The decoding result (partial or final) of current utterance.
self.hyp: List = []
# how many frames have been processed, after subsampling (i.e. a
# cumulative sum of the second return value of
# encoder.streaming_forward).
self.feature_len: int = 0
if params.decoding_method == "greedy_search":
@ -69,7 +74,7 @@ class DecodeStream(object):
else:
assert (
False
), f"Decoding method :{params.decoding_method} do not support"
), f"Decoding method :{params.decoding_method} do not support."
@property
def done(self) -> bool:
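A minimal sketch of the bookkeeping described in the comments above, assuming the ((x_len - 1) // 2 - 1) // 2 subsampling used elsewhere in this commit; the per-chunk sizes are made up, and in the real code feature_len is incremented with the lengths returned by encoder.streaming_forward rather than this helper:

def approx_subsampled_len(x_len: int) -> int:
    # Rough 4x frame-rate reduction of the convolutional front end.
    return ((x_len - 1) // 2 - 1) // 2

num_processed_frames = 0  # raw (pre-subsampling) frames consumed so far
feature_len = 0           # subsampled frames produced so far

for chunk_raw_frames in (48, 48, 32):  # hypothetical chunk sizes
    num_processed_frames += chunk_raw_frames
    feature_len += approx_subsampled_len(chunk_raw_frames)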

View File

@ -149,14 +149,6 @@ def get_parser():
are streaming model, this should be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
exporting a streaming model.
""",
)
return parser
@ -183,7 +175,7 @@ def main():
params.vocab_size = sp.get_piece_size()
if params.streaming_model:
assert params.causal_convolution
params.causal_convolution = True
logging.info(params)

View File

@ -66,8 +66,6 @@ class Transducer(nn.Module):
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
delay_penalty: float = 0.0,
return_sym_delay: bool = False,
) -> torch.Tensor:
"""
Args:
@ -138,31 +136,10 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
delay_penalty=delay_penalty,
reduction="sum",
return_grad=True,
)
sym_delay = None
if return_sym_delay:
B, S, T0 = px_grad.shape
T = T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2,
dtype=px_grad.dtype,
device=px_grad.device,
).expand(B, 1, 1)
total_syms = S * B
else:
offset = (boundary[:, 3] - 1) / 2
total_syms = torch.sum(boundary[:, 2])
offset = torch.arange(T0, device=px_grad.device).reshape(
1, 1, T0
) - offset.reshape(B, 1, 1)
sym_delay = px_grad * offset
sym_delay = torch.sum(sym_delay) / total_syms
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
@ -186,8 +163,7 @@ class Transducer(nn.Module):
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction="sum",
)
return (simple_loss, pruned_loss, sym_delay)
return (simple_loss, pruned_loss)

View File

@ -20,9 +20,12 @@ Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-size 8 \
--left-context 32 \
--right-context 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \
--num-decode-streams 200
--num-decode-streams 1000
"""
import argparse
@ -182,15 +185,6 @@ def get_parser():
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
@ -205,6 +199,13 @@ def get_parser():
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--right-context",
type=int,
default=4,
help="right context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
@ -343,14 +344,18 @@ def decode_one_chunk(
processed_feature_lens = []
for stream in decode_streams:
# We add 2 here because one frame will be cut off on each side of the
# encoder_embed output, as those frames see invalid padding, so we need 2
# extra frames.
feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor
(params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_feature_lens.append(stream.feature_len)
if params.decoding_method == "fast_beam_search":
processed_feature_lens.append(stream.feature_len)
rnnt_stream_list.append(stream.rnnt_decoding_stream)
feature_lens = torch.tensor(feature_lens, device=device)
@ -358,15 +363,21 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
if features.size(1) < 7:
feature_lens += 7 - features.size(1)
# We add 2 here because one frame will be cut off on each side of the
# encoder_embed output, as those frames see invalid padding, so we need 2
# extra frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(
[
features,
torch.tensor(
LOG_EPS, dtype=features.dtype, device=device
).expand(
features.size(0), 7 - features.size(1), features.size(2)
features.size(0),
tail_length - features.size(1),
features.size(2),
),
],
dim=1,
@ -377,12 +388,16 @@ def decode_one_chunk(
torch.stack([x[1] for x in states], dim=2),
]
processed_feature_lens = torch.tensor(processed_feature_lens, device=device)
# Note: states will be modified in streaming_forward.
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
left_context=params.left_context,
right_context=params.right_context,
processed_lens=processed_feature_lens,
)
if params.decoding_method == "greedy_search":
@ -395,9 +410,6 @@ def decode_one_chunk(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
processed_feature_lens = torch.tensor(
processed_feature_lens, device=device
)
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
processed_lens = processed_feature_lens + encoder_out_lens
hyp_tokens = fast_beam_search(
@ -411,8 +423,8 @@ def decode_one_chunk(
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].feature_len += encoder_out_lens[i]
if params.decoding_method == "fast_beam_search":
decode_streams[i].feature_len += encoder_out_lens[i]
decode_streams[i].hyp = hyp_tokens[i]
if decode_streams[i].done:
finished_streams.append(i)
@ -457,12 +469,14 @@ def decode_dataset(
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 300
log_interval = 50
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device)
initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
@ -536,7 +550,7 @@ def decode_dataset(
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
@ -597,6 +611,7 @@ def main():
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
params.suffix += f"-right-context-{params.right_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
@ -620,10 +635,7 @@ def main():
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
params.causal_convolution = True
logging.info(params)
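For readers skimming the chunk bookkeeping above, a small back-of-the-envelope sketch; the values come from the usage example, and subsampling_factor=4 is an assumption about the Conv2dSubsampling front end, not something stated in this diff:

decode_chunk_size = 8     # encoder frames per chunk (after subsampling)
right_context = 2         # look-ahead frames (after subsampling)
subsampling_factor = 4    # assumed frame-rate reduction of encoder_embed

# Raw feature frames fetched per chunk; the +2 covers the one frame cut off
# on each side of the encoder_embed output.
frames_per_chunk = (decode_chunk_size + 2 + right_context) * subsampling_factor
assert frames_per_chunk == 48

# Minimum chunk length before padding: 7 raw frames survive
# ((x_len - 1) // 2 - 1) // 2 as one output frame, and the extra
# (2 + right_context) * subsampling_factor frames cover the trimmed edges
# and the look-ahead that is dropped after the encoder.
tail_length = 7 + (2 + right_context) * subsampling_factor
assert tail_length == 23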

View File

@ -37,7 +37,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 4 \
--max-duration 300
@ -244,15 +243,6 @@ def get_parser():
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
@ -269,25 +259,6 @@ def get_parser():
help="How many left context can be seen in chunks when calculating attention.",
)
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser
@ -554,17 +525,14 @@ def compute_loss(
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
sym_delay = None
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, sym_delay = model(
simple_loss, pruned_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
delay_penalty=params.delay_penalty,
return_sym_delay=params.return_sym_delay,
)
loss = params.simple_loss_scale * simple_loss + pruned_loss
@ -582,9 +550,6 @@ def compute_loss(
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if sym_delay is not None:
info["sym_delay"] = sym_delay.detatch().cpu().item()
return loss, info
@ -839,9 +804,8 @@ def run(rank, world_size, args):
params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
# dynamic_chunk_training requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -279,19 +279,26 @@ class Conformer(EncoderInterface):
The chunk size for decoding, this will be used to simulate streaming
decoding using masking.
left_context:
How many old frames the attention can see in current chunk, it MUST
be equal to left_context in decode_states.
How many previous frames the attention can see in the current chunk.
Note: not every frame has exactly `left_context` frames of left
context; some have more.
right_context:
How many future frames the attention can see in the current chunk.
Note: not every frame has exactly `right_context` frames of right
context; some have more.
simulate_streaming:
If True, use a masking strategy to simulate streaming decoding
(i.e. each chunk sees only a limited left and right context). The
whole sequence is expected to be sent at once when
simulate_streaming is True.
processed_lens:
How many frames (after subsampling) have been processed for each sequence.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
- decode_states, the updated DecodeStates including the information
- decode_states, the updated states including the information
of current chunk.
"""
@ -321,8 +328,6 @@ class Conformer(EncoderInterface):
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}."""
# src_key_padding_mask = make_pad_mask(lengths + left_context)
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths)
@ -341,6 +346,8 @@ class Conformer(EncoderInterface):
embed = self.encoder_embed(x)
# cut off 1 frame on each side of embed as those frames see the padding
# value, which causes a training and decoding mismatch.
embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context)
@ -359,7 +366,8 @@ class Conformer(EncoderInterface):
x = x[0:-right_context, ...]
lengths -= right_context
else:
# This branch simulates streaming decoding with the same masking
# strategy as used at training time.
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
@ -558,9 +566,14 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
left_context:
How many previous frames the attention can see in the current chunk.
Note: not every frame has exactly `left_context` frames of left
context; some have more.
right_context:
How many future frames the attention can see in the current chunk.
Note: not every frame has exactly `right_context` frames of right
context; some have more.
Shape:
src: (S, N, E).
@ -708,10 +721,14 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of layers; if < 1.0, we will
bypass layers more frequently.
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
left_context:
How many previous frames the attention can see in the current chunk.
Note: not every frame has exactly `left_context` frames of left
context; some have more.
right_context:
How many future frames the attention can see in the current chunk.
Note: not every frame has exactly `right_context` frames of right
context; some have more.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
@ -1273,9 +1290,17 @@ class RelPositionMultiheadAttention(nn.Module):
and attn_mask.dtype == torch.bool
and key_padding_mask is not None
):
combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
if attn_mask.size(0) != 1:
attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
combined_mask = attn_mask | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
else:
# attn_mask.shape == (1, tgt_len, src_len)
combined_mask = attn_mask.unsqueeze(
0
) | key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
@ -1404,6 +1429,10 @@ class ConvolutionModule(nn.Module):
x: Input tensor (#time, batch, channels).
cache: The cache of depthwise_conv, only used in real streaming
decoding.
right_context:
How many future frames the attention can see in the current chunk.
Note: not every frame has exactly `right_context` frames of right
context; some have more.
Returns:
If cache is None return the output tensor (#time, batch, channels).
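The new branch added to RelPositionMultiheadAttention above selects how the boolean attn_mask is combined with key_padding_mask, depending on whether the mask is per-head or shared. A minimal sketch of both cases with toy shapes (all sizes made up), assuming plain torch broadcasting:

import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[0, -2:] = True  # last two keys of sequence 0 are padding

# Case 1: a per-head mask of shape (bsz * num_heads, tgt_len, src_len).
attn_mask = torch.zeros(bsz * num_heads, tgt_len, src_len, dtype=torch.bool)
combined = attn_mask.view(bsz, num_heads, tgt_len, src_len) | \
    key_padding_mask.unsqueeze(1).unsqueeze(2)
# combined: (bsz, num_heads, tgt_len, src_len)

# Case 2: a shared mask of shape (1, tgt_len, src_len); broadcasting adds
# the batch and head dimensions.
attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.bool)
combined = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
# combined: (bsz, 1, tgt_len, src_len)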

View File

@ -59,8 +59,7 @@ Usage:
--epoch 28 \
--avg 15 \
--simulate-streaming 1 \
--causal-convolution 1 \
--right-chunk-size 16 \
--decode-chunk-size 16 \
--left-context 64 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
@ -257,16 +256,7 @@ def get_parser():
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
@ -335,11 +325,11 @@ def decode_one_batch(
)
if params.simulate_streaming:
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
@ -561,7 +551,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
@ -594,9 +584,8 @@ def main():
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
# Decoding in streaming requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -164,14 +164,6 @@ def get_parser():
are streaming model, this should be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
exporting a streaming model.
""",
)
return parser
@ -197,7 +189,7 @@ def main():
params.vocab_size = sp.get_piece_size()
if params.streaming_model:
assert params.causal_convolution
params.causal_convolution = True
logging.info(params)

View File

@ -78,8 +78,6 @@ class Transducer(nn.Module):
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
delay_penalty: float = 0.0,
return_sym_delay: bool = False,
) -> torch.Tensor:
"""
Args:
@ -157,31 +155,10 @@ class Transducer(nn.Module):
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
delay_penalty=delay_penalty,
reduction="sum",
return_grad=True,
)
sym_delay = None
if return_sym_delay:
B, S, T0 = px_grad.shape
T = T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2,
dtype=px_grad.dtype,
device=px_grad.device,
).expand(B, 1, 1)
total_syms = S * B
else:
offset = (boundary[:, 3] - 1) / 2
total_syms = torch.sum(boundary[:, 2])
offset = torch.arange(T0, device=px_grad.device).reshape(
1, 1, T0
) - offset.reshape(B, 1, 1)
sym_delay = px_grad * offset
sym_delay = torch.sum(sym_delay) / total_syms
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
@ -210,9 +187,8 @@ class Transducer(nn.Module):
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
delay_penalty=delay_penalty,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss, sym_delay)
return (simple_loss, pruned_loss)

View File

@ -20,9 +20,12 @@ Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \
--num-decode-streams 200
--num-decode-streams 1000
"""
import argparse
@ -182,15 +185,6 @@ def get_parser():
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
@ -356,6 +350,9 @@ def decode_one_chunk(
processed_feature_lens = []
for stream in decode_streams:
# We add 2 here because one frame will be cut off on each side of the
# encoder_embed output, as those frames see invalid padding, so we need 2
# extra frames.
feat, feat_len = stream.get_feature_frames(
(params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor
@ -372,7 +369,7 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 15 + params.right_context * params.subsampling_factor
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(
@ -642,10 +639,8 @@ def main():
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
# Decoding in streaming requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -48,11 +48,9 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 4 \
--max-duration 300
"""
@ -285,15 +283,6 @@ def get_parser():
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
@ -310,25 +299,6 @@ def get_parser():
help="How many left context can be seen in chunks when calculating attention.",
)
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser
@ -611,7 +581,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, sym_delay = model(
simple_loss, pruned_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -619,8 +589,6 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
delay_penalty=params.delay_penalty,
return_sym_delay=params.return_sym_delay,
)
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
@ -650,9 +618,6 @@ def compute_loss(
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.return_sym_delay:
info["sym_delay"] = sym_delay.detach().cpu().item()
return loss, info
@ -882,13 +847,8 @@ def run(rank, world_size, args):
params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
else:
assert (
params.delay_penalty == 0.0
), "delay_penalty is intended for dynamic_chunk_training"
# dynamic_chunk_training requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -266,16 +266,7 @@ def get_parser():
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
@ -348,7 +339,7 @@ def decode_one_batch(
x=feature,
x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
@ -596,7 +587,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if params.decoding_method == "fast_beam_search":
@ -635,9 +626,8 @@ def main():
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
# Decoding in streaming requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -165,14 +165,6 @@ def get_parser():
are streaming model, this should be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
exporting a streaming model.
""",
)
return parser
@ -198,7 +190,7 @@ def main():
params.vocab_size = sp.get_piece_size()
if params.streaming_model:
assert params.causal_convolution
params.causal_convolution = True
logging.info(params)

View File

@ -20,9 +20,12 @@ Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \
--num-decode-streams 200
--num-decode-streams 1000
"""
import argparse
@ -183,15 +186,6 @@ def get_parser():
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
@ -357,6 +351,9 @@ def decode_one_chunk(
processed_feature_lens = []
for stream in decode_streams:
# We add 2 here because one frame will be cut off on each side of the
# encoder_embed output, as those frames see invalid padding, so we need 2
# extra frames.
feat, feat_len = stream.get_feature_frames(
(params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor
@ -373,7 +370,10 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 15 + params.right_context * params.subsampling_factor
# We add 2 here because one frame will be cut off on each side of the
# encoder_embed output, as those frames see invalid padding, so we need 2
# extra frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(
@ -481,7 +481,9 @@ def decode_dataset(
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device)
initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
@ -641,10 +643,8 @@ def main():
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
# Decoding in streaming requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -295,15 +295,6 @@ def get_parser():
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
@ -320,25 +311,6 @@ def get_parser():
help="How many left context can be seen in chunks when calculating attention.",
)
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser
@ -963,13 +935,8 @@ def run(rank, world_size, args):
params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
else:
assert (
params.delay_penalty == 0.0
), "delay_penalty is intended for dynamic_chunk_training"
# dynamic_chunk_training requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -60,8 +60,7 @@ Usage:
--epoch 30 \
--avg 15 \
--simulate-streaming 1 \
--causal-convolution 1 \
--right-chunk-size 16 \
--decode-chunk-size 16 \
--left-context 64 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
@ -269,16 +268,7 @@ def get_parser():
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--right-chunk-size",
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
@ -351,7 +341,7 @@ def decode_one_batch(
x=feature,
x_lens=feature_lens,
states=[],
chunk_size=params.right_chunk_size,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
@ -573,7 +563,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
@ -609,9 +599,8 @@ def main():
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
# Decoding in streaming requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -20,6 +20,9 @@ Usage:
./pruned_transducer_stateless2/streaming_decode.py \
--epoch 28 \
--avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 2 \
--exp-dir ./pruned_transducer_stateless2/exp \
--decoding_method greedy_search \
--num-decode-streams 200
@ -194,15 +197,6 @@ def get_parser():
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=True,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
@ -368,6 +362,9 @@ def decode_one_chunk(
processed_feature_lens = []
for stream in decode_streams:
# We add 2 here because one frame will be cut off on each side of the
# encoder_embed output, as those frames see invalid padding, so we need 2
# extra frames.
feat, feat_len = stream.get_feature_frames(
(params.decode_chunk_size + 2 + params.right_context)
* params.subsampling_factor
@ -384,7 +381,10 @@ def decode_one_chunk(
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
tail_length = 15 + params.right_context * params.subsampling_factor
# We add 2 here because one frame will be cut off on each side of the
# encoder_embed output, as those frames see invalid padding, so we need 2
# extra frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
feature_lens += tail_length - features.size(1)
features = torch.cat(
@ -492,7 +492,9 @@ def decode_dataset(
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device)
initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
@ -655,10 +657,8 @@ def main():
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
# Decoding in streaming requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -49,7 +49,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 4 \
--max-duration 300
@ -302,15 +301,6 @@ def get_parser():
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
@ -327,25 +317,6 @@ def get_parser():
help="How many left context can be seen in chunks when calculating attention.",
)
parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value to penalize symbol delay, this may be
needed when training with time masking, to avoid the time masking
encouraging the network to delay symbols.
""",
)
parser.add_argument(
"--return-sym-delay",
type=str2bool,
default=False,
help="""Whether to return `sym_delay` during training, this is a stat
to measure symbols emission delay, especially for time masking training.
""",
)
return parser
@ -640,7 +611,7 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, sym_delay = model(
simple_loss, pruned_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
@ -648,8 +619,6 @@ def compute_loss(
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
delay_penalty=params.delay_penalty,
return_sym_delay=params.return_sym_delay,
)
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
@ -679,9 +648,6 @@ def compute_loss(
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.return_sym_delay:
info["sym_delay"] = sym_delay.detach().cpu().item()
return loss, info
@ -922,13 +888,8 @@ def run(rank, world_size, args):
params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
else:
assert (
params.delay_penalty == 0.0
), "delay_penalty is intended for dynamic_chunk_training"
# dynamic_chunk_training requires causal convolution
params.causal_convolution = True
logging.info(params)

View File

@ -248,7 +248,9 @@ class Conformer(Transformer):
states: List[torch.Tensor],
chunk_size: int = 16,
left_context: int = 64,
right_context: int = 0,
simulate_streaming: bool = False,
processed_lens: Optional[Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Args:
@ -268,13 +270,20 @@ class Conformer(Transformer):
The chunk size for decoding, this will be used to simulate streaming
decoding using masking.
left_context:
How many old frames the attention can see in current chunk, it MUST
be equal to left_context in decode_states.
How many previous frames the attention can see in the current chunk.
Note: not every frame has exactly `left_context` frames of left
context; some have more.
right_context:
How many future frames the attention can see in the current chunk.
Note: not every frame has exactly `right_context` frames of right
context; some have more.
simulate_streaming:
If True, use a masking strategy to simulate streaming decoding
(i.e. each chunk sees only a limited left and right context). The
whole sequence is expected to be sent at once when
simulate_streaming is True.
processed_lens:
How many frames (after subsampling) have been processed for each sequence.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
@ -310,9 +319,27 @@ class Conformer(Transformer):
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}."""
src_key_padding_mask = make_pad_mask(lengths + left_context)
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths)
assert processed_lens is not None
processed_mask = torch.arange(left_context, device=x.device).expand(
x.size(0), left_context
)
processed_lens = processed_lens.view(x.size(0), 1)
processed_mask = (processed_lens <= processed_mask).flip(1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
embed = self.encoder_embed(x)
# cut off 1 frame on each side of embed as those frames see the padding
# value, which causes a training and decoding mismatch.
embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
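The processed_mask built in the hunk above marks which of the left_context cache slots are still empty for streams that have processed fewer than left_context frames. A small sketch with made-up values:

import torch

left_context = 4
processed_lens = torch.tensor([1, 3])  # frames already processed per stream

processed_mask = torch.arange(left_context).expand(2, left_context)
processed_mask = (processed_lens.view(2, 1) <= processed_mask).flip(1)
# stream 0 -> [True, True, True, False]: 3 cache slots still empty (masked)
# stream 1 -> [True, False, False, False]: only 1 slot masked
# flip(1) puts the masked slots at the left end, matching how the attention
# cache is left-padded; the result is concatenated in front of
# src_key_padding_mask.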
@ -322,6 +349,7 @@ class Conformer(Transformer):
src_key_padding_mask=src_key_padding_mask,
states=states,
left_context=left_context,
right_context=right_context,
) # (T, B, F)
else:
src_key_padding_mask = make_pad_mask(lengths)
@ -512,6 +540,7 @@ class ConformerEncoderLayer(nn.Module):
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]:
"""
Pass the input through the encoder layer.
@ -528,9 +557,14 @@ class ConformerEncoderLayer(nn.Module):
Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
left_context:
How many previous frames the attention can see in the current chunk.
Note: not every frame has exactly `left_context` frames of left
context; some have more.
right_context:
How many future frames the attention can see in the current chunk.
Note: not every frame has exactly `right_context` frames of right
context; some have more.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
@ -562,7 +596,12 @@ class ConformerEncoderLayer(nn.Module):
# separately) if needed.
key = torch.cat([states[0], src], dim=0)
val = key
states[0] = key[-left_context:, ...]
if right_context > 0:
states[0] = key[
-(left_context + right_context) : -right_context, ... # noqa
]
else:
states[0] = key[-left_context:, ...]
src_att = self.self_attn(
src,
@ -582,7 +621,9 @@ class ConformerEncoderLayer(nn.Module):
if self.normalize_before:
src = self.norm_conv(src)
src, conv_cache = self.conv_module(src, states[1])
src, conv_cache = self.conv_module(
src, states[1], right_context=right_context
)
states[1] = conv_cache
src = residual + self.dropout(src)
@ -669,6 +710,7 @@ class ConformerEncoder(nn.Module):
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]:
r"""Pass the input through the encoder layers in turn.
@ -684,9 +726,14 @@ class ConformerEncoder(nn.Module):
Note: states will be modified in this function.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
left_context: left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
left_context:
How many previous frames the attention can see in the current chunk.
Note: not every frame has exactly `left_context` frames of left
context; some have more.
right_context:
How many future frames the attention can see in the current chunk.
Note: not every frame has exactly `right_context` frames of right
context; some have more.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
@ -707,6 +754,7 @@ class ConformerEncoder(nn.Module):
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
left_context=left_context,
right_context=right_context,
)
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
@ -1329,7 +1377,10 @@ class ConvolutionModule(nn.Module):
self.activation = Swish()
def forward(
self, x: Tensor, cache: Optional[Tensor] = None
self,
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
@ -1359,7 +1410,15 @@ class ConvolutionModule(nn.Module):
), "Cache should be None in training time"
assert cache.size(0) == self.lorder
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
if right_context > 0:
cache = x.permute(2, 0, 1)[
-(self.lorder + right_context) : ( # noqa
-right_context
),
...,
]
else:
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
x = self.depthwise_conv(x)
# x is (batch, channels, time)
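A sketch of the right_context-aware cache slicing used above for both the attention cache (states[0]) and the ConvolutionModule cache: the last left_context (or self.lorder) frames are kept, while the right_context look-ahead tail is excluded since those frames are fed again with the next chunk. Toy values only:

import torch

left_context, right_context = 4, 2
key = torch.arange(12).view(12, 1, 1)  # (time, batch, channels) toy tensor

if right_context > 0:
    cache = key[-(left_context + right_context) : -right_context, ...]
else:
    cache = key[-left_context:, ...]
# cache holds frames 6..9: the last `left_context` frames that precede the
# `right_context` look-ahead frames (10 and 11), which are not cached.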

View File

@ -535,11 +535,8 @@ class MetricsTracker(collections.defaultdict):
ans = []
for k, v in self.items():
if k != "frames":
if k != "sym_delay":
norm_value = float(v) / num_frames
ans.append((k, norm_value))
else:
ans.append((k, float(v)))
norm_value = float(v) / num_frames
ans.append((k, norm_value))
return ans
def reduce(self, device):