mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
add @torch.jit.export for init_states function
This commit is contained in:
parent
dbea9a9970
commit
1f6c822dc0
@ -1636,6 +1636,7 @@ class EmformerEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
return output, output_lengths, output_states
|
return output, output_lengths, output_states
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||||
"""Create initial states."""
|
"""Create initial states."""
|
||||||
attn_caches = [
|
attn_caches = [
|
||||||
@ -1684,6 +1685,7 @@ class Emformer(EncoderInterface):
|
|||||||
|
|
||||||
self.subsampling_factor = subsampling_factor
|
self.subsampling_factor = subsampling_factor
|
||||||
self.right_context_length = right_context_length
|
self.right_context_length = right_context_length
|
||||||
|
self.chunk_length = chunk_length
|
||||||
if subsampling_factor != 4:
|
if subsampling_factor != 4:
|
||||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||||
if chunk_length % subsampling_factor != 0:
|
if chunk_length % subsampling_factor != 0:
|
||||||
@ -1832,6 +1834,7 @@ class Emformer(EncoderInterface):
|
|||||||
|
|
||||||
return output, output_lengths, output_states
|
return output, output_lengths, output_states
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||||
"""Create initial states."""
|
"""Create initial states."""
|
||||||
return self.encoder.init_states(device)
|
return self.encoder.init_states(device)
|
||||||
|
@ -1539,6 +1539,7 @@ class EmformerEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
return output, output_lengths, output_states
|
return output, output_lengths, output_states
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||||
"""Create initial states."""
|
"""Create initial states."""
|
||||||
attn_caches = [
|
attn_caches = [
|
||||||
@ -1587,6 +1588,7 @@ class Emformer(EncoderInterface):
|
|||||||
|
|
||||||
self.subsampling_factor = subsampling_factor
|
self.subsampling_factor = subsampling_factor
|
||||||
self.right_context_length = right_context_length
|
self.right_context_length = right_context_length
|
||||||
|
self.chunk_length = chunk_length
|
||||||
if subsampling_factor != 4:
|
if subsampling_factor != 4:
|
||||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||||
if chunk_length % subsampling_factor != 0:
|
if chunk_length % subsampling_factor != 0:
|
||||||
@ -1735,6 +1737,7 @@ class Emformer(EncoderInterface):
|
|||||||
|
|
||||||
return output, output_lengths, output_states
|
return output, output_lengths, output_states
|
||||||
|
|
||||||
|
@torch.jit.export
|
||||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||||
"""Create initial states."""
|
"""Create initial states."""
|
||||||
return self.encoder.init_states(device)
|
return self.encoder.init_states(device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user