mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
modify usage of tokenizer in vits/train.py
This commit is contained in:
parent
1851443801
commit
595d4a3c47
@ -296,14 +296,16 @@ def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device):
|
||||
features_lens = batch["features_lens"].to(device)
|
||||
tokens = batch["tokens"]
|
||||
|
||||
tokens = tokenizer.tokens_to_token_ids(tokens)
|
||||
tokens = tokenizer.tokens_to_token_ids(
|
||||
tokens, intersperse_blank=True, add_sos=True, add_eos=True
|
||||
)
|
||||
tokens = k2.RaggedTensor(tokens)
|
||||
row_splits = tokens.shape.row_splits(1)
|
||||
tokens_lens = row_splits[1:] - row_splits[:-1]
|
||||
tokens = tokens.to(device)
|
||||
tokens_lens = tokens_lens.to(device)
|
||||
# a tensor of shape (B, T)
|
||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.blank_id)
|
||||
tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id)
|
||||
|
||||
return audio, audio_lens, features, features_lens, tokens, tokens_lens
|
||||
|
||||
@ -742,8 +744,7 @@ def run(rank, world_size, args):
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
tokenizer = Tokenizer(params.tokens)
|
||||
params.blank_id = tokenizer.blank_id
|
||||
params.oov_id = tokenizer.oov_id
|
||||
params.blank_id = tokenizer.pad_id
|
||||
params.vocab_size = tokenizer.vocab_size
|
||||
|
||||
logging.info(params)
|
||||
|
Loading…
x
Reference in New Issue
Block a user