Merge pull request #5 from kakashidan/zipformer2_streaming

small fixes
This commit is contained in:
danfu 2023-06-11 23:55:14 +08:00 committed by GitHub
commit 3aeed46af2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 8 deletions

View File

@@ -246,7 +246,7 @@ class OnnxEncoder(nn.Module):
new_processed_lens, new_processed_lens,
] ]
return encoder_out, encoder_out_lens, new_states return encoder_out, new_states
def get_init_states( def get_init_states(
self, self,
@@ -266,7 +266,7 @@ class OnnxEncoder(nn.Module):
embed_states = self.encoder_embed.get_init_states(batch_size, device) embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states) states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device)
states.append(processed_lens) states.append(processed_lens)
return states return states

View File

@@ -209,7 +209,7 @@ class OnnxModel:
self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2] self.states += [cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2]
embed_states = torch.zeros(batch_size, 128, 3, 19) embed_states = torch.zeros(batch_size, 128, 3, 19)
self.states.append(embed_states) self.states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32) processed_lens = torch.zeros(batch_size, dtype=torch.int64)
self.states.append(processed_lens) self.states.append(processed_lens)
self.num_encoders = num_encoders self.num_encoders = num_encoders
@@ -296,7 +296,6 @@ class OnnxModel:
encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens encoder_input[f"{name}"] = processed_lens.numpy() if isinstance(processed_lens, torch.Tensor) else processed_lens
encoder_output.append(f"new_{name}") encoder_output.append(f"new_{name}")
logging.info(f"encoder_output_len={len(encoder_output)}")
return encoder_input, encoder_output return encoder_input, encoder_output
def _update_states(self, states: List[np.ndarray]): def _update_states(self, states: List[np.ndarray]):
@@ -312,11 +311,8 @@ class OnnxModel:
T' is usually equal to ((T-7)//2+1)//2 T' is usually equal to ((T-7)//2+1)//2
""" """
encoder_input, encoder_output_names = self._build_encoder_input_output(x) encoder_input, encoder_output_names = self._build_encoder_input_output(x)
# logging.info(encoder_input.keys())
# logging.info(encoder_output_names)
out = self.encoder.run(encoder_output_names, encoder_input) out = self.encoder.run(encoder_output_names, encoder_input)
len = out.pop(1)
out.append(len.astype(np.int32))
self._update_states(out[1:]) self._update_states(out[1:])