mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
add streaming conformer
This commit is contained in:
parent
80a2e67b23
commit
227f32f089
@ -1 +0,0 @@
|
||||
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
|
1444
egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
Normal file
1444
egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -176,6 +176,40 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dynamic-chunk-training",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||
model, this requires to be True.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--causal-convolution",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use causal convolution, this requires to be True when
|
||||
using dynamic_chunk_training.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--short-chunk-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="""Chunk length of dynamic training, the chunk size would be either
|
||||
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-left-chunks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="How many left context can be seen in chunks when calculating attention.",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -468,6 +502,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
nhead=params.nhead,
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||
short_chunk_size=params.short_chunk_size,
|
||||
num_left_chunks=params.num_left_chunks,
|
||||
causal=params.causal_convolution,
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -958,6 +996,11 @@ def run(rank, world_size, args):
|
||||
params.blank_id = lexicon.token_table["<blk>"]
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
|
||||
if params.dynamic_chunk_training:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "dynamic_chunk_training requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
Loading…
x
Reference in New Issue
Block a user