add @torch.jit.export for init_states function

This commit is contained in:
yaozengwei 2022-07-06 17:46:36 +08:00
parent dbea9a9970
commit 1f6c822dc0
2 changed files with 6 additions and 0 deletions

View File

@ -1636,6 +1636,7 @@ class EmformerEncoder(nn.Module):
)
return output, output_lengths, output_states
@torch.jit.export
def init_states(self, device: torch.device = torch.device("cpu")):
"""Create initial states."""
attn_caches = [
@ -1684,6 +1685,7 @@ class Emformer(EncoderInterface):
self.subsampling_factor = subsampling_factor
self.right_context_length = right_context_length
self.chunk_length = chunk_length
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
if chunk_length % subsampling_factor != 0:
@ -1832,6 +1834,7 @@ class Emformer(EncoderInterface):
return output, output_lengths, output_states
@torch.jit.export
def init_states(self, device: torch.device = torch.device("cpu")):
"""Create initial states."""
return self.encoder.init_states(device)

View File

@ -1539,6 +1539,7 @@ class EmformerEncoder(nn.Module):
)
return output, output_lengths, output_states
@torch.jit.export
def init_states(self, device: torch.device = torch.device("cpu")):
"""Create initial states."""
attn_caches = [
@ -1587,6 +1588,7 @@ class Emformer(EncoderInterface):
self.subsampling_factor = subsampling_factor
self.right_context_length = right_context_length
self.chunk_length = chunk_length
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
if chunk_length % subsampling_factor != 0:
@ -1735,6 +1737,7 @@ class Emformer(EncoderInterface):
return output, output_lengths, output_states
@torch.jit.export
def init_states(self, device: torch.device = torch.device("cpu")):
"""Create initial states."""
return self.encoder.init_states(device)