mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Add new phonemizer to infer.py
This commit is contained in:
parent
87e1d286cf
commit
16ccbc5bc8
@ -130,14 +130,16 @@ def infer_dataset(
|
||||
batch_size = len(batch["tokens"])
|
||||
|
||||
tokens = batch["tokens"]
|
||||
tokens = tokenizer.tokens_to_token_ids(tokens)
|
||||
tokens = tokenizer.tokens_to_token_ids(
|
||||
tokens, intersperse_blank=True, add_sos=True, add_eos=True
|
||||
)
|
||||
tokens = k2.RaggedTensor(tokens)
|
||||
row_splits = tokens.shape.row_splits(1)
|
||||
tokens_lens = row_splits[1:] - row_splits[:-1]
|
||||
tokens = tokens.to(device)
|
||||
tokens_lens = tokens_lens.to(device)
|
||||
# tensor of shape (B, T)
|
||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
|
||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
|
||||
|
||||
audio = batch["audio"]
|
||||
audio_lens = batch["audio_lens"].tolist()
|
||||
@ -201,8 +203,7 @@ def main():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.blank_id = tokenizer.blank_id
|
||||
params.oov_id = tokenizer.oov_id
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user