mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-18 22:02:24 +00:00)
minor fixes
This commit is contained in:
  parent ce73643af6
  commit 2504036f5b
@@ -4,6 +4,7 @@
 #                                                       Mingshuang Luo)
 # Copyright    2023                           (authors: Feiteng Li)
 # Copyright    2024                           (authors: Yuekai Zhang)
 # Copyright    2024  Tsinghua University      (authors: Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -48,10 +49,8 @@ python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max
 import argparse
 import copy
 import logging
 import os
 import random
 import warnings
 from contextlib import nullcontext
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
@@ -686,9 +685,9 @@ def compute_validation_loss(
         output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
         output_dir.mkdir(parents=True, exist_ok=True)
         if isinstance(model, DDP):
-            model.module.visualize(predicts, batch, output_dir=output_dir)
+            model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir)
         else:
-            model.visualize(predicts, batch, output_dir=output_dir)
+            model.visualize(predicts, batch, tokenizer, output_dir=output_dir)

     return tot_loss
@@ -23,6 +23,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 import torch.nn as nn
+from tokenizer import TextTokenCollater
 from torch import Tensor
 from torch.nn import Linear, Module
 from torch.nn import functional as F
@@ -1664,13 +1665,15 @@ class VALLE(nn.Module):
         self,
         predicts: Tuple[torch.Tensor],
         batch: Dict[str, Union[List, torch.Tensor]],
+        tokenizer: TextTokenCollater,
         output_dir: str,
         limit: int = 4,
     ) -> None:
-        text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
-        text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
         audio_features = batch["audio_features"].to("cpu").detach().numpy()
         audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy()

+        tokens = batch["tokens"]
+        text_tokens, text_tokens_lens = tokenizer(tokens)
+        assert text_tokens.ndim == 2

         utt_ids, texts = batch["utt_id"], batch["text"]
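
Taken together, these hunks make VALLE.visualize() rebuild its padded text-token tensor on the fly from the raw batch["tokens"] via the TextTokenCollater, rather than reading a precomputed batch["text_tokens"] array, so compute_validation_loss() now passes the collater through to the model. Below is a minimal sketch of the resulting call path; it is not part of the commit, the helper name visualize_batch is made up for illustration, and it assumes (as the diff implies) that calling the collater on a list of token sequences returns a padded 2-D token tensor together with a 1-D lengths tensor.

    # Illustrative sketch only, not part of the commit.
    from torch.nn.parallel import DistributedDataParallel as DDP

    from tokenizer import TextTokenCollater  # recipe-local collater added to the imports above


    def visualize_batch(model, predicts, batch, tokenizer: TextTokenCollater, output_dir):
        """Hypothetical helper mirroring the updated call in compute_validation_loss()."""
        # As in the new visualize() body: collate the raw per-utterance token lists
        # into a padded (batch, max_text_len) tensor plus per-utterance lengths.
        text_tokens, text_tokens_lens = tokenizer(batch["tokens"])
        assert text_tokens.ndim == 2 and text_tokens_lens.ndim == 1

        # The collater is now threaded through the call; unwrap DDP as train.py does.
        if isinstance(model, DDP):
            model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir)
        else:
            model.visualize(predicts, batch, tokenizer, output_dir=output_dir)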