Henry Li Xinyuan b07d5472c5
Implement recipe for Fluent Speech Commands dataset (#1469)
---------

Signed-off-by: Xinyuan Li <xli257@c13.clsp.jhu.edu>
2024-01-31 22:53:36 +08:00

72 lines
2.1 KiB
Python
Executable File

# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from transducer.model import Transducer
def greedy_search(
model: Transducer, encoder_out: torch.Tensor, id2word: dict
) -> List[str]:
"""
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
device = model.device
sos = torch.tensor([blank_id], device=device).reshape(1, 1)
decoder_out, (h, c) = model.decoder(sos)
T = encoder_out.size(1)
t = 0
hyp = []
max_u = 1000 # terminate after this number of steps
u = 0
while t < T and u < max_u:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
logits = model.joiner(current_encoder_out, decoder_out)
log_prob = logits.log_softmax(dim=-1)
# log_prob is (N, 1, 1)
# TODO: Use logits.argmax()
y = log_prob.argmax()
if y != blank_id:
hyp.append(y.item())
y = y.reshape(1, 1)
decoder_out, (h, c) = model.decoder(y, (h, c))
u += 1
else:
t += 1
# id2word = {1: "YES", 2: "NO"}
hyp = [id2word[i] for i in hyp]
return hyp