Minor fixes.

This commit is contained in:
Fangjun Kuang 2022-05-20 00:43:16 +08:00
parent bcd0e872b8
commit 465803e219

View File

@ -288,18 +288,17 @@ class OfflineServer:
samples,
)
async def compute_encoder_out(
async def compute_and_decode(
self,
features: torch.Tensor,
) -> torch.Tensor:
"""Run the RNN-T encoder network.
) -> List[int]:
"""Run the RNN-T model on the features and do greedy search.
Args:
features:
A 2-D tensor of shape (num_frames, feature_dim).
Returns:
Return a 2-D tensor of shape (num_frames, encoder_out_dim) containing
the output of the encoder network.
Return a list of token IDs containing the decoded results.
"""
loop = asyncio.get_running_loop()
future = loop.create_future()
@ -322,7 +321,7 @@ class OfflineServer:
while True:
samples = await self.recv_audio_samples(socket)
features = await self.compute_features(samples)
hyp = await self.compute_encoder_out(features)
hyp = await self.compute_and_decode(features)
result = self.sp.decode(hyp)
logging.info(f"hyp: {result}")
await socket.send(result)