mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
small fixes
This commit is contained in:
parent
15c7035dad
commit
d84e086798
@ -246,7 +246,7 @@ class OnnxEncoder(nn.Module):
|
||||
new_processed_lens,
|
||||
]
|
||||
|
||||
return encoder_out, encoder_out_lens, new_states
|
||||
return encoder_out, new_states
|
||||
|
||||
def get_init_states(
|
||||
self,
|
||||
@ -266,7 +266,7 @@ class OnnxEncoder(nn.Module):
|
||||
embed_states = self.encoder_embed.get_init_states(batch_size, device)
|
||||
states.append(embed_states)
|
||||
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
|
||||
states.append(processed_lens)
|
||||
|
||||
return states
|
||||
|
@ -209,7 +209,7 @@ class OnnxModel:
|
||||
self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
|
||||
embed_states = torch.zeros(batch_size, 128, 3, 19)
|
||||
self.states.append(embed_states)
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int32)
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int64)
|
||||
self.states.append(processed_lens)
|
||||
|
||||
self.num_encoders = num_encoders
|
||||
@ -296,7 +296,6 @@ class OnnxModel:
|
||||
encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
|
||||
encoder_output.append(f"new_{name}")
|
||||
|
||||
logging.info(f"encoder_output_len={len(encoder_output)}")
|
||||
return encoder_input, encoder_output
|
||||
|
||||
def _update_states(self, states: List[np.ndarray]):
|
||||
@ -312,11 +311,8 @@ class OnnxModel:
|
||||
T' is usually equal to ((T-7)//2+1)//2
|
||||
"""
|
||||
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
|
||||
# logging.info(encoder_input.keys())
|
||||
# logging.info(encoder_output_names)
|
||||
|
||||
out = self.encoder.run(encoder_output_names, encoder_input)
|
||||
len = out.pop(1)
|
||||
out.append(len.astype(np.int32))
|
||||
|
||||
self._update_states(out[1:])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user