mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Add new phonemizer to infer.py
This commit is contained in:
parent
87e1d286cf
commit
16ccbc5bc8
@ -130,14 +130,16 @@ def infer_dataset(
|
|||||||
batch_size = len(batch["tokens"])
|
batch_size = len(batch["tokens"])
|
||||||
|
|
||||||
tokens = batch["tokens"]
|
tokens = batch["tokens"]
|
||||||
tokens = tokenizer.tokens_to_token_ids(tokens)
|
tokens = tokenizer.tokens_to_token_ids(
|
||||||
|
tokens, intersperse_blank=True, add_sos=True, add_eos=True
|
||||||
|
)
|
||||||
tokens = k2.RaggedTensor(tokens)
|
tokens = k2.RaggedTensor(tokens)
|
||||||
row_splits = tokens.shape.row_splits(1)
|
row_splits = tokens.shape.row_splits(1)
|
||||||
tokens_lens = row_splits[1:] - row_splits[:-1]
|
tokens_lens = row_splits[1:] - row_splits[:-1]
|
||||||
tokens = tokens.to(device)
|
tokens = tokens.to(device)
|
||||||
tokens_lens = tokens_lens.to(device)
|
tokens_lens = tokens_lens.to(device)
|
||||||
# tensor of shape (B, T)
|
# tensor of shape (B, T)
|
||||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
|
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
|
||||||
|
|
||||||
audio = batch["audio"]
|
audio = batch["audio"]
|
||||||
audio_lens = batch["audio_lens"].tolist()
|
audio_lens = batch["audio_lens"].tolist()
|
||||||
@ -201,8 +203,7 @@ def main():
|
|||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
tokenizer = Tokenizer(params.tokens)
|
tokenizer = Tokenizer(params.tokens)
|
||||||
params.blank_id = tokenizer.blank_id
|
params.blank_id = tokenizer.pad_id
|
||||||
params.oov_id = tokenizer.oov_id
|
|
||||||
params.vocab_size = tokenizer.vocab_size
|
params.vocab_size = tokenizer.vocab_size
|
||||||
|
|
||||||
logging.info(f"Device: {device}")
|
logging.info(f"Device: {device}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user