Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-10 17:44:20 +00:00)

Minor fixes

This commit is contained in:
parent 09b0c54983
commit 1c794e32b0
@@ -141,12 +141,20 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model; if the models in exp-dir
+        are streaming models, this should be True.
+        """,
+    )
     parser.add_argument(
         "--causal-convolution",
         type=str2bool,
         default=False,
         help="""Whether to use causal convolution; this is required to be True when
-        using dynamic_chunk_training.
+        exporting a streaming model.
         """,
     )
 
@@ -174,6 +182,9 @@ def main():
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.streaming_model:
+        assert params.causal_convolution
+
     logging.info(params)
 
     logging.info("About to create model")
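Both export scripts parse the new flags with str2bool, and main() now ties them together. A minimal, self-contained sketch of that flag handling; the str2bool body below is a typical implementation (icefall ships its own in icefall.utils), not code from this commit:

    import argparse

    def str2bool(v):
        # Typical string-to-bool converter; icefall's own lives in icefall.utils.
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser()
    parser.add_argument("--streaming-model", type=str2bool, default=False)
    parser.add_argument("--causal-convolution", type=str2bool, default=False)

    args = parser.parse_args(
        ["--streaming-model", "1", "--causal-convolution", "1"]
    )
    # The check added in main(): a streaming export requires causal convolution.
    if args.streaming_model:
        assert args.causal_convolution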
@@ -122,7 +122,7 @@ class Conformer(EncoderInterface):
             causal,
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
-        self._init_state = torch.jit.Attribute([], List[torch.Tensor])
+        self._init_state: List[torch.Tensor] = [torch.empty(0)]
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
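The _init_state change replaces the legacy torch.jit.Attribute wrapper with a plain annotated assignment; a type annotation plus a non-empty default is enough for torch.jit.script to type the attribute. A minimal sketch, using a toy module rather than the real Conformer:

    from typing import List

    import torch
    import torch.nn as nn

    class Encoder(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # The annotation gives the attribute a List[torch.Tensor] type for
            # TorchScript; [torch.empty(0)] avoids an empty, untypable default.
            self._init_state: List[torch.Tensor] = [torch.empty(0)]

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x

    scripted = torch.jit.script(Encoder())  # scripts cleanly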
@@ -1396,7 +1396,7 @@ class ConvolutionModule(nn.Module):
         self,
         x: Tensor,
         cache: Optional[Tensor] = None,
-        right_context=0,
+        right_context: int = 0,
     ) -> Tuple[Tensor, Tensor]:
         """Compute convolution module.
 
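The right_context annotation matters for the same TorchScript reason: torch.jit.script assumes an unannotated parameter is a Tensor, so an int default without ": int" fails to compile. A small illustration with a toy module, not the real ConvolutionModule:

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        # Without ": int", TorchScript would type right_context as Tensor
        # and reject the int default and the integer slicing below.
        def forward(self, x: torch.Tensor, right_context: int = 0) -> torch.Tensor:
            if right_context > 0:
                x = x[:, :-right_context]
            return x

    scripted = torch.jit.script(Toy())
    print(scripted(torch.randn(2, 10), 2).shape)  # torch.Size([2, 8])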
@@ -156,12 +156,20 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model; if the models in exp-dir
+        are streaming models, this should be True.
+        """,
+    )
     parser.add_argument(
         "--causal-convolution",
         type=str2bool,
         default=False,
         help="""Whether to use causal convolution; this is required to be True when
-        using dynamic_chunk_training.
+        exporting a streaming model.
         """,
     )
 
@@ -188,6 +196,9 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.streaming_model:
+        assert params.causal_convolution
+
     logging.info(params)
 
     logging.info("About to create model")
@@ -480,7 +480,9 @@ def decode_dataset(
     decode_results = []
     # Decode streams currently running.
     decode_streams = []
-    initial_states = model.get_init_state(params.left_context, device=device)
+    initial_states = model.encoder.get_init_state(
+        params.left_context, device=device
+    )
     for num, cut in enumerate(cuts):
         # Each utterance has a DecodeStream.
         decode_stream = DecodeStream(
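The streaming-decode fix calls get_init_state on model.encoder, where the helper is defined, rather than on the top-level transducer model. A hedged sketch of what such a helper looks like; the state shapes below are illustrative assumptions, not the real conformer.py layout:

    from typing import List

    import torch
    import torch.nn as nn

    class StreamingConformer(nn.Module):
        def __init__(self, num_layers: int = 12, d_model: int = 512) -> None:
            super().__init__()
            self.num_layers = num_layers
            self.d_model = d_model

        def get_init_state(
            self, left_context: int, device: torch.device
        ) -> List[torch.Tensor]:
            # One zeroed attention cache covering `left_context` frames and
            # one empty convolution cache; illustrative shapes only.
            attn_cache = torch.zeros(
                self.num_layers, left_context, self.d_model, device=device
            )
            conv_cache = torch.zeros(
                self.num_layers, self.d_model, 0, device=device
            )
            return [attn_cache, conv_cache]

    encoder = StreamingConformer()
    initial_states = encoder.get_init_state(64, device=torch.device("cpu"))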
@@ -125,6 +125,55 @@ def get_parser():
         "2 means tri-gram",
     )
 
+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training; if you want a streaming
+        model, this is required to be True.
+        Note: not needed here; added only to construct the transducer model,
+        as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length for dynamic training; the chunk size is either the
+        max sequence length of the current batch or uniformly sampled from (1, short_chunk_size).
+        Note: not needed here; added only to construct the transducer model,
+        as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="""How much left context can be seen in chunks when calculating attention.
+        Note: not needed here; added only to construct the transducer model,
+        as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--streaming-model",
+        type=str2bool,
+        default=False,
+        help="""Whether to export a streaming model; if the models in exp-dir
+        are streaming models, this should be True.
+        """,
+    )
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution; this is required to be True
+        when exporting a streaming model.
+        """,
+    )
+
     return parser
 
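The --short-chunk-size help text describes the dynamic-chunk-training rule: a batch trains either on full context (chunk size = max sequence length of the batch) or on a chunk drawn uniformly from (1, short_chunk_size). A sketch of that sampling rule; the 50/50 branch probability is an assumption for illustration, and the real schedule lives in the Conformer training code:

    import random

    def sample_chunk_size(max_len: int, short_chunk_size: int = 25) -> int:
        # Either full context or a uniform draw from (1, short_chunk_size),
        # per the help text; the 50/50 split is an assumed probability.
        if random.random() < 0.5:
            return max_len
        return random.randint(1, short_chunk_size)

    print(sample_chunk_size(max_len=180))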
@@ -148,6 +197,9 @@ def main():
     params.blank_id = sp.piece_to_id("<blk>")
     params.vocab_size = sp.get_piece_size()
 
+    if params.streaming_model:
+        assert params.causal_convolution
+
     logging.info(params)
 
     logging.info("About to create model")
@@ -125,7 +125,7 @@ class Conformer(Transformer):
         # and throws an error without this change.
         self.after_norm = identity
 
-        self._init_state = torch.jit.Attribute([], List[torch.Tensor])
+        self._init_state: List[torch.Tensor] = [torch.empty(0)]
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor