Minor fixes.

This commit is contained in:
Fangjun Kuang 2022-05-20 00:43:16 +08:00
parent bcd0e872b8
commit 465803e219

View File

@ -288,18 +288,17 @@ class OfflineServer:
samples, samples,
) )
async def compute_encoder_out( async def compute_and_decode(
self, self,
features: torch.Tensor, features: torch.Tensor,
) -> torch.Tensor: ) -> List[int]:
"""Run the RNN-T encoder network. """Run the RNN-T model on the features and do greedy search.
Args: Args:
features: features:
A 2-D tensor of shape (num_frames, feature_dim). A 2-D tensor of shape (num_frames, feature_dim).
Returns: Returns:
Return a 2-D tensor of shape (num_frames, encoder_out_dim) containing Return a list of token IDs containing the decoded results.
the output of the encoder network.
""" """
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future = loop.create_future() future = loop.create_future()
@ -322,7 +321,7 @@ class OfflineServer:
while True: while True:
samples = await self.recv_audio_samples(socket) samples = await self.recv_audio_samples(socket)
features = await self.compute_features(samples) features = await self.compute_features(samples)
hyp = await self.compute_encoder_out(features) hyp = await self.compute_and_decode(features)
result = self.sp.decode(hyp) result = self.sp.decode(hyp)
logging.info(f"hyp: {result}") logging.info(f"hyp: {result}")
await socket.send(result) await socket.send(result)