Mirror of https://github.com/k2-fsa/icefall.git

Commit d7be9bd9c5 ("Minor fixes"), parent 1c794e32b0.
@@ -73,7 +73,7 @@ Usage:
     --avg 15 \
     --simulate-streaming 1 \
     --causal-convolution 1 \
-    --right-chunk-size 16 \
+    --decode-chunk-size 16 \
     --left-context 64 \
     --exp-dir ./pruned_transducer_stateless/exp \
     --max-duration 600 \
@@ -302,16 +302,9 @@ def get_parser():
         test a streaming model.
         """,
     )

     parser.add_argument(
-        "--causal-convolution",
-        type=str2bool,
-        default=False,
-        help="""Whether to use causal convolution, this requires to be True when
-        using dynamic_chunk_training.
-        """,
-    )
-    parser.add_argument(
-        "--right-chunk-size",
+        "--decode-chunk-size",
         type=int,
         default=16,
         help="The chunk size for decoding (in frames after subsampling)",
@@ -379,7 +372,7 @@ def decode_one_batch(
         x=feature,
         x_lens=feature_lens,
         states=[],
-        chunk_size=params.right_chunk_size,
+        chunk_size=params.decode_chunk_size,
         left_context=params.left_context,
         simulate_streaming=True,
     )
@@ -610,7 +603,7 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     if params.simulate_streaming:
-        params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
+        params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
         params.suffix += f"-left-context-{params.left_context}"

     if "fast_beam_search" in params.decoding_method:
@@ -646,9 +639,8 @@ def main():
     params.vocab_size = sp.get_piece_size()

     if params.simulate_streaming:
-        assert (
-            params.causal_convolution
-        ), "Decoding in streaming requires causal convolution"
+        # Decoding in streaming requires causal convolution
+        params.causal_convolution = True

     logging.info(params)

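
[Note] A pattern that repeats throughout this commit: decoding and export
entry points used to assert that the user had passed --causal-convolution 1,
and now simply switch the flag on when it is implied. A minimal sketch of the
before/after behaviour (not the actual icefall code; `params` stands in for
its AttributeDict):

    from argparse import Namespace

    params = Namespace(simulate_streaming=True, causal_convolution=False)

    # Old behaviour: hard failure unless the user remembered the flag.
    if params.simulate_streaming:
        try:
            assert params.causal_convolution, "Decoding in streaming requires causal convolution"
        except AssertionError as e:
            print(e)  # what users used to hit

    # New behaviour: streaming implies causal convolution, so just set it.
    if params.simulate_streaming:
        # Decoding in streaming requires causal convolution
        params.causal_convolution = True
    print(params.causal_convolution)  # True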
----- next file (filenames are not preserved in this mirror) -----

@@ -37,6 +37,7 @@ class DecodeStream(object):
         `get_init_state` in conformer.py
       decoding_graph:
         Decoding graph used for decoding, may be a TrivialGraph or a HLG.
+        Used only when decoding_method is fast_beam_search.
       device:
         The device to run this stream.
     """
@@ -49,7 +50,8 @@ class DecodeStream(object):

         # It contains a 2-D tensors representing the feature frames.
         self.features: torch.Tensor = None
-        # how many frames are processed. (before subsampling).
+        # how many frames have been processed. (before subsampling).
+        # we only modify this value in `func:get_feature_frames`.
         self.num_processed_frames: int = 0
         self._done: bool = False
         # The transcript of current utterance.
@@ -57,6 +59,9 @@ class DecodeStream(object):
         # The decoding result (partial or final) of current utterance.
         self.hyp: List = []

+        # how many frames have been processed, after subsampling (i.e. a
+        # cumulative sum of the second return value of
+        # encoder.streaming_forward
        self.feature_len: int = 0

         if params.decoding_method == "greedy_search":
@@ -69,7 +74,7 @@ class DecodeStream(object):
         else:
             assert (
                 False
-            ), f"Decoding method :{params.decoding_method} do not support"
+            ), f"Decoding method :{params.decoding_method} do not support."

     @property
     def done(self) -> bool:
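
[Note] The new DecodeStream fields keep two counters per utterance: raw frames
consumed before subsampling, and encoder frames produced after subsampling
(`feature_len`, the cumulative sum of the second return value of
encoder.streaming_forward). A small stand-in sketch of the bookkeeping
(hypothetical chunk sizes; not the real class):

    class TinyStream:
        def __init__(self):
            self.num_processed_frames = 0  # raw frames (before subsampling)
            self.feature_len = 0           # encoder frames (after subsampling)

    stream = TinyStream()
    for chunk_frames, encoder_out_len in [(64, 16), (64, 16), (40, 10)]:
        stream.num_processed_frames += chunk_frames  # done in get_feature_frames
        stream.feature_len += encoder_out_len        # done after streaming_forward
    print(stream.num_processed_frames, stream.feature_len)  # 168 42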
----- next file -----

@@ -149,14 +149,6 @@ def get_parser():
         are streaming model, this should be True.
         """,
     )
-    parser.add_argument(
-        "--causal-convolution",
-        type=str2bool,
-        default=False,
-        help="""Whether to use causal convolution, this requires to be True when
-        exporting a streaming model.
-        """,
-    )

     return parser

@@ -183,7 +175,7 @@ def main():
     params.vocab_size = sp.get_piece_size()

     if params.streaming_model:
-        assert params.causal_convolution
+        params.causal_convolution = True

     logging.info(params)


----- next file -----

@@ -66,8 +66,6 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
-        delay_penalty: float = 0.0,
-        return_sym_delay: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -138,31 +136,10 @@ class Transducer(nn.Module):
             lm_only_scale=lm_scale,
             am_only_scale=am_scale,
             boundary=boundary,
-            delay_penalty=delay_penalty,
             reduction="sum",
             return_grad=True,
         )

-        sym_delay = None
-        if return_sym_delay:
-            B, S, T0 = px_grad.shape
-            T = T0 - 1
-            if boundary is None:
-                offset = torch.tensor(
-                    (T - 1) / 2,
-                    dtype=px_grad.dtype,
-                    device=px_grad.device,
-                ).expand(B, 1, 1)
-                total_syms = S * B
-            else:
-                offset = (boundary[:, 3] - 1) / 2
-                total_syms = torch.sum(boundary[:, 2])
-            offset = torch.arange(T0, device=px_grad.device).reshape(
-                1, 1, T0
-            ) - offset.reshape(B, 1, 1)
-            sym_delay = px_grad * offset
-            sym_delay = torch.sum(sym_delay) / total_syms
-
         # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
             px_grad=px_grad,
@@ -186,8 +163,7 @@ class Transducer(nn.Module):
             ranges=ranges,
             termination_symbol=blank_id,
             boundary=boundary,
-            delay_penalty=delay_penalty,
             reduction="sum",
         )

-        return (simple_loss, pruned_loss, sym_delay)
+        return (simple_loss, pruned_loss)
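
[Note] The `sym_delay` block deleted from Transducer.forward computed a delay
statistic from the gradient of the simple RNN-T loss: px_grad distributes each
symbol's emission over frames, so weighting it by each frame's signed distance
from the utterance midpoint gives an average emission delay. A self-contained
sketch of that computation with stand-in tensors (shapes follow the deleted
code: px_grad is (B, S, T+1); boundary columns 2 and 3 hold symbol and frame
counts):

    import torch

    B, S, T0 = 2, 5, 11
    px_grad = torch.softmax(torch.randn(B, S, T0), dim=-1)  # stand-in gradient
    boundary = torch.tensor([[0, 0, 5, 10], [0, 0, 3, 8]])

    offset = (boundary[:, 3] - 1) / 2            # per-utterance midpoint frame
    total_syms = torch.sum(boundary[:, 2])
    # signed distance of every frame from the midpoint; positive = late
    offset = torch.arange(T0).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
    sym_delay = torch.sum(px_grad * offset) / total_syms
    print(float(sym_delay))  # average emission delay, in frames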
----- next file -----

@@ -20,9 +20,12 @@ Usage:
 ./pruned_transducer_stateless2/streaming_decode.py \
     --epoch 28 \
     --avg 15 \
+    --decode-chunk-size 8 \
+    --left-context 32 \
+    --right-context 2 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --decoding_method greedy_search \
-    --num-decode-streams 200
+    --num-decode-streams 1000
 """

 import argparse
@@ -182,15 +185,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--causal-convolution",
-        type=str2bool,
-        default=True,
-        help="""Whether to use causal convolution, this requires to be True when
-        using dynamic_chunk_training.
-        """,
-    )
-
     parser.add_argument(
         "--decode-chunk-size",
         type=int,
@@ -205,6 +199,13 @@ def get_parser():
         help="left context can be seen during decoding (in frames after subsampling)",
     )

+    parser.add_argument(
+        "--right-context",
+        type=int,
+        default=4,
+        help="right context can be seen during decoding (in frames after subsampling)",
+    )

     parser.add_argument(
         "--num-decode-streams",
         type=int,
@@ -343,14 +344,18 @@ def decode_one_chunk(
     processed_feature_lens = []

     for stream in decode_streams:
+        # we plus 2 here because we will cut off one frame on each size of
+        # encoder_embed output as they see invalid paddings. so we need extra 2
+        # frames.
         feat, feat_len = stream.get_feature_frames(
-            params.decode_chunk_size * params.subsampling_factor
+            (params.decode_chunk_size + 2 + params.right_context)
+            * params.subsampling_factor
         )
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
+        processed_feature_lens.append(stream.feature_len)
         if params.decoding_method == "fast_beam_search":
-            processed_feature_lens.append(stream.feature_len)
             rnnt_stream_list.append(stream.rnnt_decoding_stream)

     feature_lens = torch.tensor(feature_lens, device=device)
@@ -358,15 +363,21 @@ def decode_one_chunk(

     # if T is less than 7 there will be an error in time reduction layer,
     # because we subsample features with ((x_len - 1) // 2 - 1) // 2
-    if features.size(1) < 7:
-        feature_lens += 7 - features.size(1)
+    # we plus 2 here because we will cut off one frame on each size of
+    # encoder_embed output as they see invalid paddings. so we need extra 2
+    # frames.
+    tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
+    if features.size(1) < tail_length:
+        feature_lens += tail_length - features.size(1)
         features = torch.cat(
             [
                 features,
                 torch.tensor(
                     LOG_EPS, dtype=features.dtype, device=device
                 ).expand(
-                    features.size(0), 7 - features.size(1), features.size(2)
+                    features.size(0),
+                    tail_length - features.size(1),
+                    features.size(2),
                 ),
             ],
             dim=1,
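
[Note] The padding arithmetic above can be checked by hand. The Conv2d
subsampling maps x_len input frames to ((x_len - 1) // 2 - 1) // 2 output
frames, one output frame is then trimmed on each side of encoder_embed, and
right_context extra output frames are consumed per chunk. A worked check with
illustrative values:

    def subsampled(x_len: int) -> int:
        return ((x_len - 1) // 2 - 1) // 2

    subsampling_factor, right_context, decode_chunk_size = 4, 2, 8

    # raw frames fetched per chunk: chunk + 2 trimmed frames + right context
    fetched = (decode_chunk_size + 2 + right_context) * subsampling_factor
    print(fetched, subsampled(fetched))          # 48 11 (before trimming)

    # shortest tail that survives subsampling, trimming and right context
    tail_length = 7 + (2 + right_context) * subsampling_factor
    print(tail_length, subsampled(tail_length))  # 23 5
    print(subsampled(7))                         # 1, the bare minimum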
@@ -377,12 +388,16 @@ def decode_one_chunk(
         torch.stack([x[1] for x in states], dim=2),
     ]

+    processed_feature_lens = torch.tensor(processed_feature_lens, device=device)

     # Note: states will be modified in streaming_forward.
     encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
         x=features,
         x_lens=feature_lens,
         states=states,
         left_context=params.left_context,
+        right_context=params.right_context,
+        processed_lens=processed_feature_lens,
     )

     if params.decoding_method == "greedy_search":
@@ -395,9 +410,6 @@ def decode_one_chunk(
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        processed_feature_lens = torch.tensor(
-            processed_feature_lens, device=device
-        )
         decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
         processed_lens = processed_feature_lens + encoder_out_lens
         hyp_tokens = fast_beam_search(
@@ -411,8 +423,8 @@ def decode_one_chunk(
     finished_streams = []
     for i in range(len(decode_streams)):
         decode_streams[i].states = [states[0][i], states[1][i]]
+        decode_streams[i].feature_len += encoder_out_lens[i]
         if params.decoding_method == "fast_beam_search":
-            decode_streams[i].feature_len += encoder_out_lens[i]
             decode_streams[i].hyp = hyp_tokens[i]
         if decode_streams[i].done:
             finished_streams.append(i)
@@ -457,12 +469,14 @@ def decode_dataset(
     opts.frame_opts.samp_freq = 16000
     opts.mel_opts.num_bins = 80

-    log_interval = 300
+    log_interval = 50

     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.get_init_state(params.left_context, device=device)
+    initial_states = model.encoder.get_init_state(
+        params.left_context, device=device
+    )
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -536,7 +550,7 @@ def decode_dataset(
 def save_results(
     params: AttributeDict,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
 ):
     test_set_wers = dict()
     for key, results in results_dict.items():
@@ -597,6 +611,7 @@ def main():
     # for streaming
     params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
     params.suffix += f"-left-context-{params.left_context}"
+    params.suffix += f"-right-context-{params.right_context}"

     # for fast_beam_search
     if params.decoding_method == "fast_beam_search":
@@ -620,10 +635,7 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
-    assert (
-        params.causal_convolution
-    ), "Decoding in streaming requires causal convolution"
+    params.causal_convolution = True

     logging.info(params)


----- next file -----

@@ -37,7 +37,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
     --exp-dir pruned_transducer_stateless/exp \
     --full-libri 1 \
     --dynamic-chunk-training 1 \
-    --causal-convolution 1 \
     --short-chunk-size 25 \
     --num-left-chunks 4 \
     --max-duration 300
@@ -244,15 +243,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--causal-convolution",
-        type=str2bool,
-        default=False,
-        help="""Whether to use causal convolution, this requires to be True when
-        using dynamic_chunk_training.
-        """,
-    )
-
     parser.add_argument(
         "--short-chunk-size",
         type=int,
@@ -269,25 +259,6 @@ def get_parser():
         help="How many left context can be seen in chunks when calculating attention.",
     )

-    parser.add_argument(
-        "--delay-penalty",
-        type=float,
-        default=0.0,
-        help="""A constant value to penalize symbol delay, this may be
-        needed when training with time masking, to avoid the time masking
-        encouraging the network to delay symbols.
-        """,
-    )
-
-    parser.add_argument(
-        "--return-sym-delay",
-        type=str2bool,
-        default=False,
-        help="""Whether to return `sym_delay` during training, this is a stat
-        to measure symbols emission delay, especially for time masking training.
-        """,
-    )
-
     return parser


@@ -554,17 +525,14 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

-    sym_delay = None
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, sym_delay = model(
+        simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
-            delay_penalty=params.delay_penalty,
-            return_sym_delay=params.return_sym_delay,
         )
         loss = params.simple_loss_scale * simple_loss + pruned_loss

@@ -582,9 +550,6 @@ def compute_loss(
         info["simple_loss"] = simple_loss.detach().cpu().item()
         info["pruned_loss"] = pruned_loss.detach().cpu().item()

-        if sym_delay is not None:
-            info["sym_delay"] = sym_delay.detatch().cpu().item()
-
     return loss, info


@@ -839,9 +804,8 @@ def run(rank, world_size, args):
     params.vocab_size = sp.get_piece_size()

     if params.dynamic_chunk_training:
-        assert (
-            params.causal_convolution
-        ), "dynamic_chunk_training requires causal convolution"
+        # dynamic_chunk_training requires causal convolution
+        params.causal_convolution = True

     logging.info(params)


----- next file -----

@@ -279,19 +279,26 @@ class Conformer(EncoderInterface):
             The chunk size for decoding, this will be used to simulate streaming
             decoding using masking.
           left_context:
-            How many old frames the attention can see in current chunk, it MUST
-            be equal to left_context in decode_states.
+            How many previous frames the attention can see in current chunk.
+            Note: It's not that each individual frame has `left_context` frames
+            of left context, some have more.
+          right_context:
+            How many future frames the attention can see in current chunk.
+            Note: It's not that each individual frame has `right_context` frames
+            of right context, some have more.
          simulate_streaming:
            If setting True, it will use a masking strategy to simulate streaming
            fashion (i.e. every chunk data only see limited left context and
            right context). The whole sequence is supposed to be send at a time
            When using simulate_streaming.
+          processed_lens:
+            How many frames (after subsampling) have been processed for each sequence.
         Returns:
           Return a tuple containing 2 tensors:
            - logits, its shape is (batch_size, output_seq_len, output_dim)
            - logit_lens, a tensor of shape (batch_size,) containing the number
              of frames in `logits` before padding.
-           - decode_states, the updated DecodeStates including the information
+           - decode_states, the updated states including the information
              of current chunk.
         """

@@ -321,8 +328,6 @@ class Conformer(EncoderInterface):
             {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
             given {states[1].shape}."""

-            # src_key_padding_mask = make_pad_mask(lengths + left_context)
-
             lengths -= 2  # we will cut off 1 frame on each side of encoder_embed output

             src_key_padding_mask = make_pad_mask(lengths)
@@ -341,6 +346,8 @@ class Conformer(EncoderInterface):

             embed = self.encoder_embed(x)

+            # cut off 1 frame on each size of embed as they see the padding
+            # value which causes a training and decoding mismatch.
             embed = embed[:, 1:-1, :]

             embed, pos_enc = self.encoder_pos(embed, left_context)
@@ -359,7 +366,8 @@ class Conformer(EncoderInterface):
                 x = x[0:-right_context, ...]
                 lengths -= right_context
         else:
+            # this branch simulates streaming decoding using mask as we are
+            # using in training time.
             src_key_padding_mask = make_pad_mask(lengths)
             x = self.encoder_embed(x)
             x, pos_emb = self.encoder_pos(x)
@@ -558,9 +566,14 @@ class ConformerEncoderLayer(nn.Module):
             src_key_padding_mask: the mask for the src keys per batch (optional).
             warmup: controls selective bypass of of layers; if < 1.0, we will
               bypass layers more frequently.
-            left_context: left context (in frames) used during streaming decoding.
-              this is used only in real streaming decoding, in other circumstances,
-              it MUST be 0.
+            left_context:
+              How many previous frames the attention can see in current chunk.
+              Note: It's not that each individual frame has `left_context` frames
+              of left context, some have more.
+            right_context:
+              How many future frames the attention can see in current chunk.
+              Note: It's not that each individual frame has `right_context` frames
+              of right context, some have more.

         Shape:
             src: (S, N, E).
@@ -708,10 +721,14 @@ class ConformerEncoder(nn.Module):
             src_key_padding_mask: the mask for the src keys per batch (optional).
             warmup: controls selective bypass of of layers; if < 1.0, we will
              bypass layers more frequently.
-            left_context: left context (in frames) used during streaming decoding.
-              this is used only in real streaming decoding, in other circumstances,
-              it MUST be 0.
+            left_context:
+              How many previous frames the attention can see in current chunk.
+              Note: It's not that each individual frame has `left_context` frames
+              of left context, some have more.
+            right_context:
+              How many future frames the attention can see in current chunk.
+              Note: It's not that each individual frame has `right_context` frames
+              of right context, some have more.
         Shape:
             src: (S, N, E).
             pos_emb: (N, 2*(S+left_context)-1, E).
@@ -1273,9 +1290,17 @@ class RelPositionMultiheadAttention(nn.Module):
             and attn_mask.dtype == torch.bool
             and key_padding_mask is not None
         ):
-            combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(
-                1
-            ).unsqueeze(2)
+            if attn_mask.size(0) != 1:
+                attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
+                combined_mask = attn_mask | key_padding_mask.unsqueeze(
+                    1
+                ).unsqueeze(2)
+            else:
+                # attn_mask.shape == (1, tgt_len, src_len)
+                combined_mask = attn_mask.unsqueeze(
+                    0
+                ) | key_padding_mask.unsqueeze(1).unsqueeze(2)

             attn_output_weights = attn_output_weights.view(
                 bsz, num_heads, tgt_len, src_len
             )
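
[Note] The attention-mask change above handles two attn_mask layouts: a single
(1, tgt_len, src_len) mask shared by all heads, and a per-head
(bsz * num_heads, tgt_len, src_len) mask. A minimal shape check of the two
broadcasts against key_padding_mask (stand-in sizes):

    import torch

    bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
    key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)

    # Per-head mask: reshape, then OR with (bsz, 1, 1, src_len).
    attn_mask = torch.zeros(bsz * num_heads, tgt_len, src_len, dtype=torch.bool)
    combined = attn_mask.view(
        bsz, num_heads, tgt_len, src_len
    ) | key_padding_mask.unsqueeze(1).unsqueeze(2)
    print(combined.shape)  # torch.Size([2, 4, 5, 5])

    # Shared mask: (1, 1, 5, 5) | (2, 1, 1, 5) -> (2, 1, 5, 5),
    # which later broadcasts over the head dimension.
    attn_mask = torch.zeros(1, tgt_len, src_len, dtype=torch.bool)
    combined = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze(1).unsqueeze(2)
    print(combined.shape)  # torch.Size([2, 1, 5, 5])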
@@ -1404,6 +1429,10 @@ class ConvolutionModule(nn.Module):
             x: Input tensor (#time, batch, channels).
             cache: The cache of depthwise_conv, only used in real streaming
               decoding.
+            right_context:
+              How many future frames the attention can see in current chunk.
+              Note: It's not that each individual frame has `right_context` frames
+              of right context, some have more.

         Returns:
             If cache is None return the output tensor (#time, batch, channels).
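
[Note] The docstrings added above repeat one subtlety: left_context and
right_context are granted per chunk, not per frame, so frames near a chunk
boundary see more than the nominal amount. A toy illustration of per-frame
visible ranges under chunked masking (illustrative numbers, not the icefall
mask code):

    chunk_size, left_context, right_context, num_frames = 4, 4, 2, 12

    for t in range(num_frames):
        chunk_start = (t // chunk_size) * chunk_size
        lo = max(0, chunk_start - left_context)  # left edge shared by the chunk
        hi = min(num_frames, chunk_start + chunk_size + right_context)
        print(f"frame {t:2d} attends to [{lo}, {hi})")
    # e.g. frame 4 sees [0, 10): 4 frames back but 5 ahead, while frame 7
    # gets the same window: 7 back and only 2 ahead.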
|
@ -59,8 +59,7 @@ Usage:
|
|||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--simulate-streaming 1 \
|
--simulate-streaming 1 \
|
||||||
--causal-convolution 1 \
|
--decode-chunk-size 16 \
|
||||||
--right-chunk-size 16 \
|
|
||||||
--left-context 64 \
|
--left-context 64 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
@ -257,16 +256,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--causal-convolution",
|
"--decode-chunk-size",
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
using dynamic_chunk_training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--right-chunk-size",
|
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
help="The chunk size for decoding (in frames after subsampling)",
|
||||||
@ -335,11 +325,11 @@ def decode_one_batch(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
encoder_out, encoder_out_lens = model.encoder.streaming_forward(
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
states=[],
|
states=[],
|
||||||
chunk_size=params.right_chunk_size,
|
chunk_size=params.decode_chunk_size,
|
||||||
left_context=params.left_context,
|
left_context=params.left_context,
|
||||||
simulate_streaming=True,
|
simulate_streaming=True,
|
||||||
)
|
)
|
||||||
@ -561,7 +551,7 @@ def main():
|
|||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
|
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||||
params.suffix += f"-left-context-{params.left_context}"
|
params.suffix += f"-left-context-{params.left_context}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
@ -594,9 +584,8 @@ def main():
|
|||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
assert (
|
# Decoding in streaming requires causal convolution
|
||||||
params.causal_convolution
|
params.causal_convolution = True
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -164,14 +164,6 @@ def get_parser():
|
|||||||
are streaming model, this should be True.
|
are streaming model, this should be True.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--causal-convolution",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
exporting a streaming model.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -197,7 +189,7 @@ def main():
|
|||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
assert params.causal_convolution
|
params.causal_convolution = True
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -78,8 +78,6 @@ class Transducer(nn.Module):
|
|||||||
am_scale: float = 0.0,
|
am_scale: float = 0.0,
|
||||||
lm_scale: float = 0.0,
|
lm_scale: float = 0.0,
|
||||||
warmup: float = 1.0,
|
warmup: float = 1.0,
|
||||||
delay_penalty: float = 0.0,
|
|
||||||
return_sym_delay: bool = False,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -157,31 +155,10 @@ class Transducer(nn.Module):
|
|||||||
lm_only_scale=lm_scale,
|
lm_only_scale=lm_scale,
|
||||||
am_only_scale=am_scale,
|
am_only_scale=am_scale,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
delay_penalty=delay_penalty,
|
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
return_grad=True,
|
return_grad=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
sym_delay = None
|
|
||||||
if return_sym_delay:
|
|
||||||
B, S, T0 = px_grad.shape
|
|
||||||
T = T0 - 1
|
|
||||||
if boundary is None:
|
|
||||||
offset = torch.tensor(
|
|
||||||
(T - 1) / 2,
|
|
||||||
dtype=px_grad.dtype,
|
|
||||||
device=px_grad.device,
|
|
||||||
).expand(B, 1, 1)
|
|
||||||
total_syms = S * B
|
|
||||||
else:
|
|
||||||
offset = (boundary[:, 3] - 1) / 2
|
|
||||||
total_syms = torch.sum(boundary[:, 2])
|
|
||||||
offset = torch.arange(T0, device=px_grad.device).reshape(
|
|
||||||
1, 1, T0
|
|
||||||
) - offset.reshape(B, 1, 1)
|
|
||||||
sym_delay = px_grad * offset
|
|
||||||
sym_delay = torch.sum(sym_delay) / total_syms
|
|
||||||
|
|
||||||
# ranges : [B, T, prune_range]
|
# ranges : [B, T, prune_range]
|
||||||
ranges = k2.get_rnnt_prune_ranges(
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
px_grad=px_grad,
|
px_grad=px_grad,
|
||||||
@ -210,9 +187,8 @@ class Transducer(nn.Module):
|
|||||||
symbols=y_padded,
|
symbols=y_padded,
|
||||||
ranges=ranges,
|
ranges=ranges,
|
||||||
termination_symbol=blank_id,
|
termination_symbol=blank_id,
|
||||||
delay_penalty=delay_penalty,
|
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
)
|
)
|
||||||
|
|
||||||
return (simple_loss, pruned_loss, sym_delay)
|
return (simple_loss, pruned_loss)
|
||||||
|
@ -20,9 +20,12 @@ Usage:
|
|||||||
./pruned_transducer_stateless2/streaming_decode.py \
|
./pruned_transducer_stateless2/streaming_decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
|
--left-context 32 \
|
||||||
|
--decode-chunk-size 8 \
|
||||||
|
--right-context 2 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--decoding_method greedy_search \
|
--decoding_method greedy_search \
|
||||||
--num-decode-streams 200
|
--num-decode-streams 1000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -182,15 +185,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--causal-convolution",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
using dynamic_chunk_training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decode-chunk-size",
|
"--decode-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -356,6 +350,9 @@ def decode_one_chunk(
|
|||||||
processed_feature_lens = []
|
processed_feature_lens = []
|
||||||
|
|
||||||
for stream in decode_streams:
|
for stream in decode_streams:
|
||||||
|
# we plus 2 here because we will cut off one frame on each size of
|
||||||
|
# encoder_embed output as they see invalid paddings. so we need extra 2
|
||||||
|
# frames.
|
||||||
feat, feat_len = stream.get_feature_frames(
|
feat, feat_len = stream.get_feature_frames(
|
||||||
(params.decode_chunk_size + 2 + params.right_context)
|
(params.decode_chunk_size + 2 + params.right_context)
|
||||||
* params.subsampling_factor
|
* params.subsampling_factor
|
||||||
@ -372,7 +369,7 @@ def decode_one_chunk(
|
|||||||
|
|
||||||
# if T is less than 7 there will be an error in time reduction layer,
|
# if T is less than 7 there will be an error in time reduction layer,
|
||||||
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
|
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
|
||||||
tail_length = 15 + params.right_context * params.subsampling_factor
|
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
|
||||||
if features.size(1) < tail_length:
|
if features.size(1) < tail_length:
|
||||||
feature_lens += tail_length - features.size(1)
|
feature_lens += tail_length - features.size(1)
|
||||||
features = torch.cat(
|
features = torch.cat(
|
||||||
@ -642,10 +639,8 @@ def main():
|
|||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
# Decoding in streaming requires causal convolution
|
||||||
assert (
|
params.causal_convolution = True
|
||||||
params.causal_convolution
|
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -48,11 +48,9 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--exp-dir pruned_transducer_stateless/exp \
|
--exp-dir pruned_transducer_stateless/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--dynamic-chunk-training 1 \
|
--dynamic-chunk-training 1 \
|
||||||
--causal-convolution 1 \
|
|
||||||
--short-chunk-size 25 \
|
--short-chunk-size 25 \
|
||||||
--num-left-chunks 4 \
|
--num-left-chunks 4 \
|
||||||
--max-duration 300
|
--max-duration 300
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -285,15 +283,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--causal-convolution",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
using dynamic_chunk_training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--short-chunk-size",
|
"--short-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -310,25 +299,6 @@ def get_parser():
|
|||||||
help="How many left context can be seen in chunks when calculating attention.",
|
help="How many left context can be seen in chunks when calculating attention.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--delay-penalty",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
help="""A constant value to penalize symbol delay, this may be
|
|
||||||
needed when training with time masking, to avoid the time masking
|
|
||||||
encouraging the network to delay symbols.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--return-sym-delay",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to return `sym_delay` during training, this is a stat
|
|
||||||
to measure symbols emission delay, especially for time masking training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -611,7 +581,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, sym_delay = model(
|
simple_loss, pruned_loss = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -619,8 +589,6 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
warmup=warmup,
|
warmup=warmup,
|
||||||
delay_penalty=params.delay_penalty,
|
|
||||||
return_sym_delay=params.return_sym_delay,
|
|
||||||
)
|
)
|
||||||
# after the main warmup step, we keep pruned_loss_scale small
|
# after the main warmup step, we keep pruned_loss_scale small
|
||||||
# for the same amount of time (model_warm_step), to avoid
|
# for the same amount of time (model_warm_step), to avoid
|
||||||
@ -650,9 +618,6 @@ def compute_loss(
|
|||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
|
|
||||||
if params.return_sym_delay:
|
|
||||||
info["sym_delay"] = sym_delay.detach().cpu().item()
|
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
|
|
||||||
@ -882,13 +847,8 @@ def run(rank, world_size, args):
|
|||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.dynamic_chunk_training:
|
if params.dynamic_chunk_training:
|
||||||
assert (
|
# dynamic_chunk_training requires causal convolution
|
||||||
params.causal_convolution
|
params.causal_convolution = True
|
||||||
), "dynamic_chunk_training requires causal convolution"
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
params.delay_penalty == 0.0
|
|
||||||
), "delay_penalty is intended for dynamic_chunk_training"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -266,16 +266,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--causal-convolution",
|
"--decode-chunk-size",
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
using dynamic_chunk_training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--right-chunk-size",
|
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
help="The chunk size for decoding (in frames after subsampling)",
|
||||||
@ -348,7 +339,7 @@ def decode_one_batch(
|
|||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
states=[],
|
states=[],
|
||||||
chunk_size=params.right_chunk_size,
|
chunk_size=params.decode_chunk_size,
|
||||||
left_context=params.left_context,
|
left_context=params.left_context,
|
||||||
simulate_streaming=True,
|
simulate_streaming=True,
|
||||||
)
|
)
|
||||||
@ -596,7 +587,7 @@ def main():
|
|||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
|
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||||
params.suffix += f"-left-context-{params.left_context}"
|
params.suffix += f"-left-context-{params.left_context}"
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
@ -635,9 +626,8 @@ def main():
|
|||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
assert (
|
# Decoding in streaming requires causal convolution
|
||||||
params.causal_convolution
|
params.causal_convolution = True
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -165,14 +165,6 @@ def get_parser():
|
|||||||
are streaming model, this should be True.
|
are streaming model, this should be True.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--causal-convolution",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
exporting a streaming model.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -198,7 +190,7 @@ def main():
|
|||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
assert params.causal_convolution
|
params.causal_convolution = True
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -20,9 +20,12 @@ Usage:
|
|||||||
./pruned_transducer_stateless2/streaming_decode.py \
|
./pruned_transducer_stateless2/streaming_decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
|
--left-context 32 \
|
||||||
|
--decode-chunk-size 8 \
|
||||||
|
--right-context 2 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--decoding_method greedy_search \
|
--decoding_method greedy_search \
|
||||||
--num-decode-streams 200
|
--num-decode-streams 1000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@ -183,15 +186,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--causal-convolution",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
using dynamic_chunk_training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decode-chunk-size",
|
"--decode-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -357,6 +351,9 @@ def decode_one_chunk(
|
|||||||
processed_feature_lens = []
|
processed_feature_lens = []
|
||||||
|
|
||||||
for stream in decode_streams:
|
for stream in decode_streams:
|
||||||
|
# we plus 2 here because we will cut off one frame on each size of
|
||||||
|
# encoder_embed output as they see invalid paddings. so we need extra 2
|
||||||
|
# frames.
|
||||||
feat, feat_len = stream.get_feature_frames(
|
feat, feat_len = stream.get_feature_frames(
|
||||||
(params.decode_chunk_size + 2 + params.right_context)
|
(params.decode_chunk_size + 2 + params.right_context)
|
||||||
* params.subsampling_factor
|
* params.subsampling_factor
|
||||||
@ -373,7 +370,10 @@ def decode_one_chunk(
|
|||||||
|
|
||||||
# if T is less than 7 there will be an error in time reduction layer,
|
# if T is less than 7 there will be an error in time reduction layer,
|
||||||
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
|
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
|
||||||
tail_length = 15 + params.right_context * params.subsampling_factor
|
# we plus 2 here because we will cut off one frame on each size of
|
||||||
|
# encoder_embed output as they see invalid paddings. so we need extra 2
|
||||||
|
# frames.
|
||||||
|
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
|
||||||
if features.size(1) < tail_length:
|
if features.size(1) < tail_length:
|
||||||
feature_lens += tail_length - features.size(1)
|
feature_lens += tail_length - features.size(1)
|
||||||
features = torch.cat(
|
features = torch.cat(
|
||||||
@ -481,7 +481,9 @@ def decode_dataset(
|
|||||||
decode_results = []
|
decode_results = []
|
||||||
# Contain decode streams currently running.
|
# Contain decode streams currently running.
|
||||||
decode_streams = []
|
decode_streams = []
|
||||||
initial_states = model.get_init_state(params.left_context, device=device)
|
initial_states = model.encoder.get_init_state(
|
||||||
|
params.left_context, device=device
|
||||||
|
)
|
||||||
for num, cut in enumerate(cuts):
|
for num, cut in enumerate(cuts):
|
||||||
# each utterance has a DecodeStream.
|
# each utterance has a DecodeStream.
|
||||||
decode_stream = DecodeStream(
|
decode_stream = DecodeStream(
|
||||||
@ -641,10 +643,8 @@ def main():
|
|||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
# Decoding in streaming requires causal convolution
|
||||||
assert (
|
params.causal_convolution = True
|
||||||
params.causal_convolution
|
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -295,15 +295,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--causal-convolution",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
using dynamic_chunk_training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--short-chunk-size",
|
"--short-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -320,25 +311,6 @@ def get_parser():
|
|||||||
help="How many left context can be seen in chunks when calculating attention.",
|
help="How many left context can be seen in chunks when calculating attention.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--delay-penalty",
|
|
||||||
type=float,
|
|
||||||
default=0.0,
|
|
||||||
help="""A constant value to penalize symbol delay, this may be
|
|
||||||
needed when training with time masking, to avoid the time masking
|
|
||||||
encouraging the network to delay symbols.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--return-sym-delay",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to return `sym_delay` during training, this is a stat
|
|
||||||
to measure symbols emission delay, especially for time masking training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -963,13 +935,8 @@ def run(rank, world_size, args):
|
|||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.dynamic_chunk_training:
|
if params.dynamic_chunk_training:
|
||||||
assert (
|
# dynamic_chunk_training requires causal convolution
|
||||||
params.causal_convolution
|
params.causal_convolution = True
|
||||||
), "dynamic_chunk_training requires causal convolution"
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
params.delay_penalty == 0.0
|
|
||||||
), "delay_penalty is intended for dynamic_chunk_training"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -60,8 +60,7 @@ Usage:
|
|||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--simulate-streaming 1 \
|
--simulate-streaming 1 \
|
||||||
--causal-convolution 1 \
|
--decode-chunk-size 16 \
|
||||||
--right-chunk-size 16 \
|
|
||||||
--left-context 64 \
|
--left-context 64 \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
@ -269,16 +268,7 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--causal-convolution",
|
"--decode-chunk-size",
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
using dynamic_chunk_training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--right-chunk-size",
|
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
help="The chunk size for decoding (in frames after subsampling)",
|
||||||
@ -351,7 +341,7 @@ def decode_one_batch(
|
|||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
states=[],
|
states=[],
|
||||||
chunk_size=params.right_chunk_size,
|
chunk_size=params.decode_chunk_size,
|
||||||
left_context=params.left_context,
|
left_context=params.left_context,
|
||||||
simulate_streaming=True,
|
simulate_streaming=True,
|
||||||
)
|
)
|
||||||
@ -573,7 +563,7 @@ def main():
|
|||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
|
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||||
params.suffix += f"-left-context-{params.left_context}"
|
params.suffix += f"-left-context-{params.left_context}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
@ -609,9 +599,8 @@ def main():
|
|||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.simulate_streaming:
|
if params.simulate_streaming:
|
||||||
assert (
|
# Decoding in streaming requires causal convolution
|
||||||
params.causal_convolution
|
params.causal_convolution = True
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -20,6 +20,9 @@ Usage:
|
|||||||
./pruned_transducer_stateless2/streaming_decode.py \
|
./pruned_transducer_stateless2/streaming_decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
|
--left-context 32 \
|
||||||
|
--decode-chunk-size 8 \
|
||||||
|
--right-context 2 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||||
--decoding_method greedy_search \
|
--decoding_method greedy_search \
|
||||||
--num-decode-streams 200
|
--num-decode-streams 200
|
||||||
@ -194,15 +197,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--causal-convolution",
|
|
||||||
type=str2bool,
|
|
||||||
default=True,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
using dynamic_chunk_training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--decode-chunk-size",
|
"--decode-chunk-size",
|
||||||
type=int,
|
type=int,
|
||||||
@ -368,6 +362,9 @@ def decode_one_chunk(
|
|||||||
processed_feature_lens = []
|
processed_feature_lens = []
|
||||||
|
|
||||||
for stream in decode_streams:
|
for stream in decode_streams:
|
||||||
|
# we plus 2 here because we will cut off one frame on each size of
|
||||||
|
# encoder_embed output as they see invalid paddings. so we need extra 2
|
||||||
|
# frames.
|
||||||
feat, feat_len = stream.get_feature_frames(
|
feat, feat_len = stream.get_feature_frames(
|
||||||
(params.decode_chunk_size + 2 + params.right_context)
|
(params.decode_chunk_size + 2 + params.right_context)
|
||||||
* params.subsampling_factor
|
* params.subsampling_factor
|
||||||
@@ -384,7 +381,10 @@ def decode_one_chunk(

     # if T is less than 7 there will be an error in time reduction layer,
     # because we subsample features with ((x_len - 1) // 2 - 1) // 2
-    tail_length = 15 + params.right_context * params.subsampling_factor
+    # We add 2 here because we will cut off one frame on each side of
+    # encoder_embed output (they see invalid paddings), so we need 2
+    # extra frames.
+    tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
     if features.size(1) < tail_length:
         feature_lens += tail_length - features.size(1)
         features = torch.cat(
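To sanity-check the arithmetic in this hunk: the encoder_embed subsampling maps x_len input frames to ((x_len - 1) // 2 - 1) // 2 output frames, so 7 input frames yield exactly one output frame, and each further output frame (the 2 that get cut off plus the right-context lookahead) costs subsampling_factor input frames. A toy check (a sketch only; the value 4 for subsampling_factor is the default in these recipes):

    def num_output_frames(x_len: int) -> int:
        # Subsampling performed by the Conv2d encoder_embed.
        return ((x_len - 1) // 2 - 1) // 2

    subsampling_factor = 4
    right_context = 2
    tail_length = 7 + (2 + right_context) * subsampling_factor  # = 23
    assert num_output_frames(7) == 1
    # 5 output frames: 1 usable + 2 cut off + 2 frames of right context.
    assert num_output_frames(tail_length) == 2 + right_context + 1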
@@ -492,7 +492,9 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.get_init_state(params.left_context, device=device)
+    initial_states = model.encoder.get_init_state(
+        params.left_context, device=device
+    )
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
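The initial states now come from the encoder submodule rather than the transducer wrapper. Judging from the shape assertion in conformer.py further down, the returned list is laid out roughly as annotated here (the shapes are an inference from that assert, not verified against get_init_state itself):

    # states[0]: attention cache, roughly
    #   (num_encoder_layers, left_context, batch, d_model)
    # states[1]: convolution cache, roughly
    #   (num_encoder_layers, cnn_module_kernel - 1, batch, d_model)
    initial_states = model.encoder.get_init_state(
        params.left_context, device=device
    )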
@@ -655,10 +657,8 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

-    assert (
-        params.causal_convolution
-    ), "Decoding in streaming requires causal convolution"
+    # Decoding in streaming requires causal convolution
+    params.causal_convolution = True

     logging.info(params)

@@ -49,7 +49,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir pruned_transducer_stateless4/exp \
   --full-libri 1 \
   --dynamic-chunk-training 1 \
-  --causal-convolution 1 \
   --short-chunk-size 25 \
   --num-left-chunks 4 \
   --max-duration 300
@@ -302,15 +301,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--causal-convolution",
-        type=str2bool,
-        default=False,
-        help="""Whether to use causal convolution, this requires to be True when
-        using dynamic_chunk_training.
-        """,
-    )
-
     parser.add_argument(
         "--short-chunk-size",
         type=int,
@@ -327,25 +317,6 @@ def get_parser():
         help="How many left context can be seen in chunks when calculating attention.",
     )

-    parser.add_argument(
-        "--delay-penalty",
-        type=float,
-        default=0.0,
-        help="""A constant value to penalize symbol delay, this may be
-        needed when training with time masking, to avoid the time masking
-        encouraging the network to delay symbols.
-        """,
-    )
-
-    parser.add_argument(
-        "--return-sym-delay",
-        type=str2bool,
-        default=False,
-        help="""Whether to return `sym_delay` during training, this is a stat
-        to measure symbols emission delay, especially for time masking training.
-        """,
-    )
-
     return parser

@@ -640,7 +611,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, sym_delay = model(
+        simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -648,8 +619,6 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
-            delay_penalty=params.delay_penalty,
-            return_sym_delay=params.return_sym_delay,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -679,9 +648,6 @@ def compute_loss(
         info["simple_loss"] = simple_loss.detach().cpu().item()
         info["pruned_loss"] = pruned_loss.detach().cpu().item()

-        if params.return_sym_delay:
-            info["sym_delay"] = sym_delay.detach().cpu().item()
-
     return loss, info

@@ -922,13 +888,8 @@ def run(rank, world_size, args):
     params.vocab_size = sp.get_piece_size()

     if params.dynamic_chunk_training:
-        assert (
-            params.causal_convolution
-        ), "dynamic_chunk_training requires causal convolution"
-    else:
-        assert (
-            params.delay_penalty == 0.0
-        ), "delay_penalty is intended for dynamic_chunk_training"
+        # dynamic_chunk_training requires causal convolution
+        params.causal_convolution = True

     logging.info(params)

@@ -248,7 +248,9 @@ class Conformer(Transformer):
         states: List[torch.Tensor],
         chunk_size: int = 16,
         left_context: int = 64,
+        right_context: int = 0,
         simulate_streaming: bool = False,
+        processed_lens: Optional[Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         """
         Args:
@@ -268,13 +270,20 @@ class Conformer(Transformer):
             The chunk size for decoding, this will be used to simulate streaming
             decoding using masking.
           left_context:
-            How many old frames the attention can see in current chunk, it MUST
-            be equal to left_context in decode_states.
+            How many previous frames the attention can see in the current chunk.
+            Note: it's not that each individual frame has `left_context` frames
+            of left context; some have more.
+          right_context:
+            How many future frames the attention can see in the current chunk.
+            Note: it's not that each individual frame has `right_context` frames
+            of right context; some have more.
           simulate_streaming:
             If True, a masking strategy is used to simulate streaming fashion
             (i.e. every chunk sees only limited left and right context). The
             whole sequence is supposed to be sent at a time when using
             simulate_streaming.
+          processed_lens:
+            How many frames (after subsampling) have been processed for each
+            sequence.
         Returns:
           Return a tuple containing 2 tensors:
            - logits, its shape is (batch_size, output_seq_len, output_dim)
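The note added to this docstring is worth unpacking: the visible window is aligned to chunk boundaries, so the first frame of a chunk sees exactly left_context past frames while the last frame of the same chunk sees left_context + chunk_size - 1. A toy mask builder illustrating the idea (an illustration only, not the masking code used by the recipe):

    import torch

    def chunk_attn_mask(seq_len: int, chunk_size: int, left_context: int):
        # allowed[i, j] == True means frame i may attend to frame j.
        allowed = torch.zeros(seq_len, seq_len, dtype=torch.bool)
        for i in range(seq_len):
            chunk_start = (i // chunk_size) * chunk_size
            lo = max(0, chunk_start - left_context)
            hi = chunk_start + chunk_size  # frames in a chunk see each other
            allowed[i, lo:hi] = True
        return allowed

    m = chunk_attn_mask(seq_len=8, chunk_size=4, left_context=4)
    assert m[4].sum() == 8  # first frame of chunk 1: 4 past frames + its chunk
    assert m[7].sum() == 8  # last frame of chunk 1: 7 past frames + itself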
@@ -310,9 +319,27 @@ class Conformer(Transformer):
                 {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
                 given {states[1].shape}."""

-            src_key_padding_mask = make_pad_mask(lengths + left_context)
+            lengths -= 2  # we will cut off 1 frame on each side of encoder_embed output
+            src_key_padding_mask = make_pad_mask(lengths)
+
+            assert processed_lens is not None
+
+            processed_mask = torch.arange(left_context, device=x.device).expand(
+                x.size(0), left_context
+            )
+            processed_lens = processed_lens.view(x.size(0), 1)
+            processed_mask = (processed_lens <= processed_mask).flip(1)
+
+            src_key_padding_mask = torch.cat(
+                [processed_mask, src_key_padding_mask], dim=1
+            )
+
             embed = self.encoder_embed(x)
+
+            # Cut off 1 frame on each side of embed as they see the padding
+            # value, which causes a training/decoding mismatch.
+            embed = embed[:, 1:-1, :]
+
             embed, pos_enc = self.encoder_pos(embed, left_context)
             embed = embed.permute(1, 0, 2)  # (B, T, F) -> (T, B, F)

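Re-running the new processed_mask logic on toy values shows what it computes: for each stream, the cache slots that have not yet been filled (because fewer than left_context frames were processed) come out True, i.e. treated as padding:

    import torch

    left_context = 4
    # Frames already processed per stream (toy values).
    processed_lens = torch.tensor([0, 2, 6])

    processed_mask = torch.arange(left_context).expand(3, left_context)
    processed_mask = (processed_lens.view(3, 1) <= processed_mask).flip(1)
    # tensor([[ True,  True,  True,  True],    # nothing processed yet
    #         [ True,  True, False, False],    # only the 2 newest slots are real
    #         [False, False, False, False]])   # cache fully populated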
@@ -322,6 +349,7 @@ class Conformer(Transformer):
                 src_key_padding_mask=src_key_padding_mask,
                 states=states,
                 left_context=left_context,
+                right_context=right_context,
             )  # (T, B, F)
         else:
             src_key_padding_mask = make_pad_mask(lengths)
@@ -512,6 +540,7 @@ class ConformerEncoderLayer(nn.Module):
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         left_context: int = 0,
+        right_context: int = 0,
     ) -> Tuple[Tensor, List[Tensor]]:
         """
         Pass the input through the encoder layer.
@@ -528,9 +557,14 @@ class ConformerEncoderLayer(nn.Module):
             Note: states will be modified in this function.
           src_mask: the mask for the src sequence (optional).
           src_key_padding_mask: the mask for the src keys per batch (optional).
-          left_context: left context (in frames) used during streaming decoding.
-            this is used only in real streaming decoding, in other circumstances,
-            it MUST be 0.
+          left_context:
+            How many previous frames the attention can see in the current chunk.
+            Note: it's not that each individual frame has `left_context` frames
+            of left context; some have more.
+          right_context:
+            How many future frames the attention can see in the current chunk.
+            Note: it's not that each individual frame has `right_context` frames
+            of right context; some have more.
         Shape:
             src: (S, N, E).
             pos_emb: (N, 2*(S+left_context)-1, E).
@@ -562,7 +596,12 @@ class ConformerEncoderLayer(nn.Module):
             # separately) if needed.
             key = torch.cat([states[0], src], dim=0)
             val = key
-            states[0] = key[-left_context:, ...]
+            if right_context > 0:
+                states[0] = key[
+                    -(left_context + right_context) : -right_context, ...  # noqa
+                ]
+            else:
+                states[0] = key[-left_context:, ...]

             src_att = self.self_attn(
                 src,
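The branching above keeps lookahead frames out of the attention cache: with right_context > 0, the tail frames of the current chunk will be seen again as ordinary frames of the next chunk, so caching them would duplicate them. A one-line check on a toy tensor:

    import torch

    left_context, right_context = 4, 2
    key = torch.arange(10)  # stand-in for the time axis of `key`

    cached = key[-(left_context + right_context) : -right_context]
    assert cached.tolist() == [4, 5, 6, 7]  # last 4 frames before the lookahead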
@@ -582,7 +621,9 @@ class ConformerEncoderLayer(nn.Module):
         if self.normalize_before:
             src = self.norm_conv(src)

-        src, conv_cache = self.conv_module(src, states[1])
+        src, conv_cache = self.conv_module(
+            src, states[1], right_context=right_context
+        )
         states[1] = conv_cache
         src = residual + self.dropout(src)

@@ -669,6 +710,7 @@ class ConformerEncoder(nn.Module):
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         left_context: int = 0,
+        right_context: int = 0,
     ) -> Tuple[Tensor, List[Tensor]]:
         r"""Pass the input through the encoder layers in turn.

@@ -684,9 +726,14 @@ class ConformerEncoder(nn.Module):
             Note: states will be modified in this function.
           mask: the mask for the src sequence (optional).
           src_key_padding_mask: the mask for the src keys per batch (optional).
-          left_context: left context (in frames) used during streaming decoding.
-            this is used only in real streaming decoding, in other circumstances,
-            it MUST be 0.
+          left_context:
+            How many previous frames the attention can see in the current chunk.
+            Note: it's not that each individual frame has `left_context` frames
+            of left context; some have more.
+          right_context:
+            How many future frames the attention can see in the current chunk.
+            Note: it's not that each individual frame has `right_context` frames
+            of right context; some have more.
         Shape:
             src: (S, N, E).
             pos_emb: (N, 2*(S+left_context)-1, E).
@@ -707,6 +754,7 @@ class ConformerEncoder(nn.Module):
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
                 left_context=left_context,
+                right_context=right_context,
             )
             states[0][layer_index] = cache[0]
             states[1][layer_index] = cache[1]
@@ -1329,7 +1377,10 @@ class ConvolutionModule(nn.Module):
         self.activation = Swish()

     def forward(
-        self, x: Tensor, cache: Optional[Tensor] = None
+        self,
+        x: Tensor,
+        cache: Optional[Tensor] = None,
+        right_context: int = 0,
     ) -> Tuple[Tensor, Tensor]:
         """Compute convolution module.

@@ -1359,7 +1410,15 @@ class ConvolutionModule(nn.Module):
                 ), "Cache should be None in training time"
                 assert cache.size(0) == self.lorder
                 x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
-                cache = x.permute(2, 0, 1)[-self.lorder :, ...]  # noqa
+                if right_context > 0:
+                    cache = x.permute(2, 0, 1)[
+                        -(self.lorder + right_context) : (  # noqa
+                            -right_context
+                        ),
+                        ...,
+                    ]
+                else:
+                    cache = x.permute(2, 0, 1)[-self.lorder :, ...]  # noqa

         x = self.depthwise_conv(x)
         # x is (batch, channels, time)
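The convolution cache is updated by the same rule as the attention cache, except that it keeps self.lorder (= kernel_size - 1) frames instead of left_context. A quick toy check of the slice:

    import torch

    lorder, right_context = 3, 2
    x = torch.arange(12).view(1, 1, 12)  # (batch, channels, time)

    cache = x.permute(2, 0, 1)[-(lorder + right_context) : -right_context, ...]
    assert cache.squeeze().tolist() == [7, 8, 9]  # last 3 frames before lookahead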
@@ -535,11 +535,8 @@ class MetricsTracker(collections.defaultdict):
         ans = []
         for k, v in self.items():
             if k != "frames":
-                if k != "sym_delay":
-                    norm_value = float(v) / num_frames
-                    ans.append((k, norm_value))
-                else:
-                    ans.append((k, float(v)))
+                norm_value = float(v) / num_frames
+                ans.append((k, norm_value))
         return ans

     def reduce(self, device):
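With the sym_delay special case gone, every tracked value except the frame count is normalized per frame again. A sketch mimicking the simplified loop (a plain dict stands in for the MetricsTracker from icefall/utils.py):

    metrics = {"frames": 100, "loss": 250.0, "pruned_loss": 50.0}
    num_frames = metrics["frames"]
    ans = [(k, float(v) / num_frames) for k, v in metrics.items() if k != "frames"]
    # [('loss', 2.5), ('pruned_loss', 0.5)]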