diff --git a/egs/tokenizer/CODEC/encodec/train.py b/egs/tokenizer/CODEC/encodec/train.py index f74be7d4c..61a5e8536 100755 --- a/egs/tokenizer/CODEC/encodec/train.py +++ b/egs/tokenizer/CODEC/encodec/train.py @@ -329,11 +329,15 @@ def get_model(params: AttributeDict) -> nn.Module: params.update(inference_params) hop_length = np.prod(params.ratios) - n_q = int( - 1000 - * params.target_bandwidths[-1] - // (math.ceil(params.sampling_rate / hop_length) * 10) - ) if params.n_q is None else params.n_q + n_q = ( + int( + 1000 + * params.target_bandwidths[-1] + // (math.ceil(params.sampling_rate / hop_length) * 10) + ) + if params.n_q is None + else params.n_q + ) encoder = SEANetEncoder( n_filters=params.generator_n_filters, @@ -660,13 +664,17 @@ def train_one_epoch( # ) tb_writer.add_image( "train/speech_hat_", - np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))), + np.array( + Image.open(plot_curve(speech_hat_i, params.sampling_rate)) + ), params.batch_idx_train, dataformats="HWC", ) tb_writer.add_image( "train/speech_", - np.array(Image.open(plot_curve(speech_i, params.sampling_rate))), + np.array( + Image.open(plot_curve(speech_i, params.sampling_rate)) + ), params.batch_idx_train, dataformats="HWC", ) @@ -724,17 +732,20 @@ def train_one_epoch( # ) tb_writer.add_image( f"train/valid_speech_hat_{index}", - np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))), + np.array( + Image.open(plot_curve(speech_hat_i, params.sampling_rate)) + ), params.batch_idx_train, dataformats="HWC", ) tb_writer.add_image( f"train/valid_speech_{index}", - np.array(Image.open(plot_curve(speech_i, params.sampling_rate))), + np.array( + Image.open(plot_curve(speech_i, params.sampling_rate)) + ), params.batch_idx_train, dataformats="HWC", ) - loss_value = tot_loss["generator_loss"] / tot_loss["samples"] params.train_loss = loss_value diff --git a/egs/tokenizer/CODEC/encodec/utils.py b/egs/tokenizer/CODEC/encodec/utils.py index 77cfd3c50..d9869d595 100644 --- a/egs/tokenizer/CODEC/encodec/utils.py +++ b/egs/tokenizer/CODEC/encodec/utils.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections +import io import logging from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union @@ -118,14 +119,19 @@ def plot_feature(spectrogram): plt.close() return data -def plot_curve(speech: torch.Tensor, sampling_rate: int) -> bytes: - import io + +def plot_curve( + speech: torch.Tensor, sampling_rate: int, figsize: List[int] = [10, 22] +) -> bytes: import matplotlib.pyplot as plt import numpy as np + assert len(figsize) == 2, "figsize should be a list of two integers" + plt.figure() plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T) + plt.rcParams["figure.figsize"] = figsize buf = io.BytesIO() plt.savefig(buf, format="jpeg") buf.seek(0)