mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 13:34:20 +00:00
add @torch.jit.export for init_states function
This commit is contained in:
parent
dbea9a9970
commit
1f6c822dc0
@ -1636,6 +1636,7 @@ class EmformerEncoder(nn.Module):
|
||||
)
|
||||
return output, output_lengths, output_states
|
||||
|
||||
@torch.jit.export
|
||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||
"""Create initial states."""
|
||||
attn_caches = [
|
||||
@ -1684,6 +1685,7 @@ class Emformer(EncoderInterface):
|
||||
|
||||
self.subsampling_factor = subsampling_factor
|
||||
self.right_context_length = right_context_length
|
||||
self.chunk_length = chunk_length
|
||||
if subsampling_factor != 4:
|
||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||
if chunk_length % subsampling_factor != 0:
|
||||
@ -1832,6 +1834,7 @@ class Emformer(EncoderInterface):
|
||||
|
||||
return output, output_lengths, output_states
|
||||
|
||||
@torch.jit.export
|
||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||
"""Create initial states."""
|
||||
return self.encoder.init_states(device)
|
||||
|
@ -1539,6 +1539,7 @@ class EmformerEncoder(nn.Module):
|
||||
)
|
||||
return output, output_lengths, output_states
|
||||
|
||||
@torch.jit.export
|
||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||
"""Create initial states."""
|
||||
attn_caches = [
|
||||
@ -1587,6 +1588,7 @@ class Emformer(EncoderInterface):
|
||||
|
||||
self.subsampling_factor = subsampling_factor
|
||||
self.right_context_length = right_context_length
|
||||
self.chunk_length = chunk_length
|
||||
if subsampling_factor != 4:
|
||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||
if chunk_length % subsampling_factor != 0:
|
||||
@ -1735,6 +1737,7 @@ class Emformer(EncoderInterface):
|
||||
|
||||
return output, output_lengths, output_states
|
||||
|
||||
@torch.jit.export
|
||||
def init_states(self, device: torch.device = torch.device("cpu")):
|
||||
"""Create initial states."""
|
||||
return self.encoder.init_states(device)
|
||||
|
Loading…
x
Reference in New Issue
Block a user