mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-11 01:54:20 +00:00
Minor fixes
This commit is contained in:
parent
09b0c54983
commit
1c794e32b0
@ -141,12 +141,20 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--streaming-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to export a streaming model, if the models in exp-dir
|
||||||
|
are streaming model, this should be True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--causal-convolution",
|
"--causal-convolution",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
using dynamic_chunk_training.
|
exporting a streaming model.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -174,6 +182,9 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
if params.streaming_model:
|
||||||
|
assert params.causal_convolution
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
@ -122,7 +122,7 @@ class Conformer(EncoderInterface):
|
|||||||
causal,
|
causal,
|
||||||
)
|
)
|
||||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||||
self._init_state = torch.jit.Attribute([], List[torch.Tensor])
|
self._init_state: List[torch.Tensor] = [torch.empty(0)]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||||
@ -1396,7 +1396,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
cache: Optional[Tensor] = None,
|
cache: Optional[Tensor] = None,
|
||||||
right_context=0,
|
right_context: int = 0,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Compute convolution module.
|
"""Compute convolution module.
|
||||||
|
|
||||||
|
@ -156,12 +156,20 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--streaming-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to export a streaming model, if the models in exp-dir
|
||||||
|
are streaming model, this should be True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--causal-convolution",
|
"--causal-convolution",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=False,
|
default=False,
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
using dynamic_chunk_training.
|
exporting a streaming model.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -188,6 +196,9 @@ def main():
|
|||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
if params.streaming_model:
|
||||||
|
assert params.causal_convolution
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
@ -480,7 +480,9 @@ def decode_dataset(
|
|||||||
decode_results = []
|
decode_results = []
|
||||||
# Contain decode streams currently running.
|
# Contain decode streams currently running.
|
||||||
decode_streams = []
|
decode_streams = []
|
||||||
initial_states = model.get_init_state(params.left_context, device=device)
|
initial_states = model.encoder.get_init_state(
|
||||||
|
params.left_context, device=device
|
||||||
|
)
|
||||||
for num, cut in enumerate(cuts):
|
for num, cut in enumerate(cuts):
|
||||||
# each utterance has a DecodeStream.
|
# each utterance has a DecodeStream.
|
||||||
decode_stream = DecodeStream(
|
decode_stream = DecodeStream(
|
||||||
|
@ -125,6 +125,55 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-chunk-training",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||||
|
model, this requires to be True.
|
||||||
|
Note: not needed here, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="""Chunk length of dynamic training, the chunk size would be either
|
||||||
|
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||||
|
Note: not needed for here, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-left-chunks",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""How many left context can be seen in chunks when calculating attention.
|
||||||
|
Note: not needed here, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--streaming-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to export a streaming model, if the models in exp-dir
|
||||||
|
are streaming model, this should be True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--causal-convolution",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
|
exporting a streaming model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -148,6 +197,9 @@ def main():
|
|||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
if params.streaming_model:
|
||||||
|
assert params.causal_convolution
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
@ -125,7 +125,7 @@ class Conformer(Transformer):
|
|||||||
# and throws an error without this change.
|
# and throws an error without this change.
|
||||||
self.after_norm = identity
|
self.after_norm = identity
|
||||||
|
|
||||||
self._init_state = torch.jit.Attribute([], List[torch.Tensor])
|
self._init_state: List[torch.Tensor] = [torch.empty(0)]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||||
|
Loading…
x
Reference in New Issue
Block a user