Minor updates

* Reformat the code with black
* Make the figure size of `plot_curve` customizable via a new `figsize` parameter
This commit is contained in:
zr_jin 2024-10-23 00:44:10 +08:00
parent fe27d2ca36
commit aa9132c82c
2 changed files with 29 additions and 12 deletions

View File

@ -329,11 +329,15 @@ def get_model(params: AttributeDict) -> nn.Module:
params.update(inference_params) params.update(inference_params)
hop_length = np.prod(params.ratios) hop_length = np.prod(params.ratios)
n_q = int( n_q = (
1000 int(
* params.target_bandwidths[-1] 1000
// (math.ceil(params.sampling_rate / hop_length) * 10) * params.target_bandwidths[-1]
) if params.n_q is None else params.n_q // (math.ceil(params.sampling_rate / hop_length) * 10)
)
if params.n_q is None
else params.n_q
)
encoder = SEANetEncoder( encoder = SEANetEncoder(
n_filters=params.generator_n_filters, n_filters=params.generator_n_filters,
@ -660,13 +664,17 @@ def train_one_epoch(
# ) # )
tb_writer.add_image( tb_writer.add_image(
"train/speech_hat_", "train/speech_hat_",
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))), np.array(
Image.open(plot_curve(speech_hat_i, params.sampling_rate))
),
params.batch_idx_train, params.batch_idx_train,
dataformats="HWC", dataformats="HWC",
) )
tb_writer.add_image( tb_writer.add_image(
"train/speech_", "train/speech_",
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))), np.array(
Image.open(plot_curve(speech_i, params.sampling_rate))
),
params.batch_idx_train, params.batch_idx_train,
dataformats="HWC", dataformats="HWC",
) )
@ -724,17 +732,20 @@ def train_one_epoch(
# ) # )
tb_writer.add_image( tb_writer.add_image(
f"train/valid_speech_hat_{index}", f"train/valid_speech_hat_{index}",
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))), np.array(
Image.open(plot_curve(speech_hat_i, params.sampling_rate))
),
params.batch_idx_train, params.batch_idx_train,
dataformats="HWC", dataformats="HWC",
) )
tb_writer.add_image( tb_writer.add_image(
f"train/valid_speech_{index}", f"train/valid_speech_{index}",
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))), np.array(
Image.open(plot_curve(speech_i, params.sampling_rate))
),
params.batch_idx_train, params.batch_idx_train,
dataformats="HWC", dataformats="HWC",
) )
loss_value = tot_loss["generator_loss"] / tot_loss["samples"] loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
params.train_loss = loss_value params.train_loss = loss_value

View File

@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections import collections
import io
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
@ -118,14 +119,19 @@ def plot_feature(spectrogram):
plt.close() plt.close()
return data return data
def plot_curve(speech: torch.Tensor, sampling_rate: int) -> bytes:
import io def plot_curve(
speech: torch.Tensor, sampling_rate: int, figsize: List[int] = [10, 22]
) -> bytes:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
assert len(figsize) == 2, "figsize should be a list of two integers"
plt.figure() plt.figure()
plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T) plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T)
plt.rcParams["figure.figsize"] = figsize
buf = io.BytesIO() buf = io.BytesIO()
plt.savefig(buf, format="jpeg") plt.savefig(buf, format="jpeg")
buf.seek(0) buf.seek(0)