mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
add streaming conformer
This commit is contained in:
parent
80a2e67b23
commit
227f32f089
@ -1 +0,0 @@
|
|||||||
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py
|
|
1444
egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
Normal file
1444
egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -176,6 +176,40 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-chunk-training",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||||
|
model, this requires to be True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--causal-convolution",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
|
using dynamic_chunk_training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="""Chunk length of dynamic training, the chunk size would be either
|
||||||
|
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-left-chunks",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="How many left context can be seen in chunks when calculating attention.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -468,6 +502,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
nhead=params.nhead,
|
nhead=params.nhead,
|
||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
|
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||||
|
short_chunk_size=params.short_chunk_size,
|
||||||
|
num_left_chunks=params.num_left_chunks,
|
||||||
|
causal=params.causal_convolution,
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -958,6 +996,11 @@ def run(rank, world_size, args):
|
|||||||
params.blank_id = lexicon.token_table["<blk>"]
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
if params.dynamic_chunk_training:
|
||||||
|
assert (
|
||||||
|
params.causal_convolution
|
||||||
|
), "dynamic_chunk_training requires causal convolution"
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user