minor updates

* Reformat the code with black
* Make the figure size in plot_curve customizable via a new figsize parameter
This commit is contained in:
zr_jin 2024-10-23 00:44:10 +08:00
parent fe27d2ca36
commit aa9132c82c
2 changed files with 29 additions and 12 deletions

View File

@ -329,11 +329,15 @@ def get_model(params: AttributeDict) -> nn.Module:
params.update(inference_params)
hop_length = np.prod(params.ratios)
n_q = int(
1000
* params.target_bandwidths[-1]
// (math.ceil(params.sampling_rate / hop_length) * 10)
) if params.n_q is None else params.n_q
n_q = (
int(
1000
* params.target_bandwidths[-1]
// (math.ceil(params.sampling_rate / hop_length) * 10)
)
if params.n_q is None
else params.n_q
)
encoder = SEANetEncoder(
n_filters=params.generator_n_filters,
@ -660,13 +664,17 @@ def train_one_epoch(
# )
tb_writer.add_image(
"train/speech_hat_",
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
np.array(
Image.open(plot_curve(speech_hat_i, params.sampling_rate))
),
params.batch_idx_train,
dataformats="HWC",
)
tb_writer.add_image(
"train/speech_",
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
np.array(
Image.open(plot_curve(speech_i, params.sampling_rate))
),
params.batch_idx_train,
dataformats="HWC",
)
@ -724,17 +732,20 @@ def train_one_epoch(
# )
tb_writer.add_image(
f"train/valid_speech_hat_{index}",
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
np.array(
Image.open(plot_curve(speech_hat_i, params.sampling_rate))
),
params.batch_idx_train,
dataformats="HWC",
)
tb_writer.add_image(
f"train/valid_speech_{index}",
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
np.array(
Image.open(plot_curve(speech_i, params.sampling_rate))
),
params.batch_idx_train,
dataformats="HWC",
)
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import io
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
@ -118,14 +119,19 @@ def plot_feature(spectrogram):
plt.close()
return data
def plot_curve(speech: torch.Tensor, sampling_rate: int) -> bytes:
import io
def plot_curve(
speech: torch.Tensor, sampling_rate: int, figsize: List[int] = [10, 22]
) -> bytes:
import matplotlib.pyplot as plt
import numpy as np
assert len(figsize) == 2, "figsize should be a list of two integers"
plt.figure()
plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T)
plt.rcParams["figure.figsize"] = figsize
buf = io.BytesIO()
plt.savefig(buf, format="jpeg")
buf.seek(0)