Merge pull request #5 from kakashidan/zipformer2_streaming

small fixes
This commit is contained in:
danfu 2023-06-11 23:55:14 +08:00 committed by GitHub
commit 3aeed46af2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 8 deletions

View File

@ -246,7 +246,7 @@ class OnnxEncoder(nn.Module):
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
return encoder_out, new_states
def get_init_states(
self,
@ -266,7 +266,7 @@ class OnnxEncoder(nn.Module):
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
states.append(processed_lens)
return states

View File

@ -209,7 +209,7 @@ class OnnxModel:
self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
embed_states = torch.zeros(batch_size, 128, 3, 19)
self.states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32)
processed_lens = torch.zeros(batch_size, dtype=torch.int64)
self.states.append(processed_lens)
self.num_encoders = num_encoders
@ -296,7 +296,6 @@ class OnnxModel:
encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
encoder_output.append(f"new_{name}")
logging.info(f"encoder_output_len={len(encoder_output)}")
return encoder_input, encoder_output
def _update_states(self, states: List[np.ndarray]):
@ -312,11 +311,8 @@ class OnnxModel:
T' is usually equal to ((T-7)//2+1)//2
"""
encoder_input, encoder_output_names = self._build_encoder_input_output(x)
# logging.info(encoder_input.keys())
# logging.info(encoder_output_names)
out = self.encoder.run(encoder_output_names, encoder_input)
len = out.pop(1)
out.append(len.astype(np.int32))
self._update_states(out[1:])