Minor fixes

This commit is contained in:
pkufool 2022-06-06 20:19:49 +08:00
parent 09b0c54983
commit 1c794e32b0
6 changed files with 82 additions and 6 deletions

View File

@ -141,12 +141,20 @@ def get_parser():
""",
)
parser.add_argument(
"--streaming-model",
type=str2bool,
default=False,
help="""Whether to export a streaming model, if the models in exp-dir
are streaming model, this should be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
exporting a streaming model.
""",
)
@ -174,6 +182,9 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.streaming_model:
assert params.causal_convolution
logging.info(params)
logging.info("About to create model")

View File

@ -122,7 +122,7 @@ class Conformer(EncoderInterface):
causal,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self._init_state = torch.jit.Attribute([], List[torch.Tensor])
self._init_state: List[torch.Tensor] = [torch.empty(0)]
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@ -1396,7 +1396,7 @@ class ConvolutionModule(nn.Module):
self,
x: Tensor,
cache: Optional[Tensor] = None,
right_context=0,
right_context: int = 0,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.

View File

@ -156,12 +156,20 @@ def get_parser():
""",
)
parser.add_argument(
"--streaming-model",
type=str2bool,
default=False,
help="""Whether to export a streaming model, if the models in exp-dir
are streaming model, this should be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
exporting a streaming model.
""",
)
@ -188,6 +196,9 @@ def main():
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if params.streaming_model:
assert params.causal_convolution
logging.info(params)
logging.info("About to create model")

View File

@ -480,7 +480,9 @@ def decode_dataset(
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.get_init_state(params.left_context, device=device)
initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(

View File

@ -125,6 +125,55 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed here, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--streaming-model",
type=str2bool,
default=False,
help="""Whether to export a streaming model, if the models in exp-dir
are streaming model, this should be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
exporting a streaming model.
""",
)
return parser
@ -148,6 +197,9 @@ def main():
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if params.streaming_model:
assert params.causal_convolution
logging.info(params)
logging.info("About to create model")

View File

@ -125,7 +125,7 @@ class Conformer(Transformer):
# and throws an error without this change.
self.after_norm = identity
self._init_state = torch.jit.Attribute([], List[torch.Tensor])
self._init_state: List[torch.Tensor] = [torch.empty(0)]
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor