mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
commit
3aeed46af2
@ -246,7 +246,7 @@ class OnnxEncoder(nn.Module):
|
|||||||
new_processed_lens,
|
new_processed_lens,
|
||||||
]
|
]
|
||||||
|
|
||||||
return encoder_out, encoder_out_lens, new_states
|
return encoder_out, new_states
|
||||||
|
|
||||||
def get_init_states(
|
def get_init_states(
|
||||||
self,
|
self,
|
||||||
@ -266,7 +266,7 @@ class OnnxEncoder(nn.Module):
|
|||||||
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
||||||
states.append(embed_states)
|
states.append(embed_states)
|
||||||
|
|
||||||
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
||||||
states.append(processed_lens)
|
states.append(processed_lens)
|
||||||
|
|
||||||
return states
|
return states
|
||||||
|
@ -209,7 +209,7 @@ class OnnxModel:
|
|||||||
self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
|
self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
|
||||||
embed_states = torch.zeros(batch_size, 128, 3, 19)
|
embed_states = torch.zeros(batch_size, 128, 3, 19)
|
||||||
self.states.append(embed_states)
|
self.states.append(embed_states)
|
||||||
processed_lens = torch.zeros(batch_size, dtype=torch.int32)
|
processed_lens = torch.zeros(batch_size, dtype=torch.int64)
|
||||||
self.states.append(processed_lens)
|
self.states.append(processed_lens)
|
||||||
|
|
||||||
self.num_encoders = num_encoders
|
self.num_encoders = num_encoders
|
||||||
@ -296,7 +296,6 @@ class OnnxModel:
|
|||||||
encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
|
encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
|
||||||
encoder_output.append(f"new_{name}")
|
encoder_output.append(f"new_{name}")
|
||||||
|
|
||||||
logging.info(f"encoder_output_len={len(encoder_output)}")
|
|
||||||
return encoder_input, encoder_output
|
return encoder_input, encoder_output
|
||||||
|
|
||||||
def _update_states(self, states: List[np.ndarray]):
|
def _update_states(self, states: List[np.ndarray]):
|
||||||
@ -312,11 +311,8 @@ class OnnxModel:
|
|||||||
T' is usually equal to ((T-7)//2+1)//2
|
T' is usually equal to ((T-7)//2+1)//2
|
||||||
"""
|
"""
|
||||||
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||||
# logging.info(encoder_input.keys())
|
|
||||||
# logging.info(encoder_output_names)
|
|
||||||
out = self.encoder.run(encoder_output_names, encoder_input)
|
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||||
len = out.pop(1)
|
|
||||||
out.append(len.astype(np.int32))
|
|
||||||
|
|
||||||
self._update_states(out[1:])
|
self._update_states(out[1:])
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user