add streaming conformer

This commit is contained in:
luomingshuang 2022-07-03 18:12:39 +08:00
parent 80a2e67b23
commit 227f32f089
2 changed files with 1487 additions and 1 deletions

View File

@ -1 +0,0 @@
../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py

File diff suppressed because it is too large Load Diff

View File

@ -176,6 +176,40 @@ def add_model_arguments(parser: argparse.ArgumentParser):
""",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="How many left context can be seen in chunks when calculating attention.",
)
def get_parser():
parser = argparse.ArgumentParser(
@ -468,6 +502,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
)
return encoder
@ -958,6 +996,11 @@ def run(rank, world_size, args):
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
logging.info(params)
logging.info("About to create model")