diff --git a/bin/offline_server.py b/bin/offline_server.py index 104815acf..1d64ca82f 100755 --- a/bin/offline_server.py +++ b/bin/offline_server.py @@ -288,18 +288,17 @@ class OfflineServer: samples, ) - async def compute_encoder_out( + async def compute_and_decode( self, features: torch.Tensor, - ) -> torch.Tensor: - """Run the RNN-T encoder network. + ) -> List[int]: + """Run the RNN-T model on the features and do greedy search. Args: features: A 2-D tensor of shape (num_frames, feature_dim). Returns: - Return a 2-D tensor of shape (num_frames, encoder_out_dim) containing - the output of the encoder network. + Return a list of token IDs containing the decoded results. """ loop = asyncio.get_running_loop() future = loop.create_future() @@ -322,7 +321,7 @@ class OfflineServer: while True: samples = await self.recv_audio_samples(socket) features = await self.compute_features(samples) - hyp = await self.compute_encoder_out(features) + hyp = await self.compute_and_decode(features) result = self.sp.decode(hyp) logging.info(f"hyp: {result}") await socket.send(result)