mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
minor updates
* black formatted * customize fig size
This commit is contained in:
parent
fe27d2ca36
commit
aa9132c82c
@ -329,11 +329,15 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
params.update(inference_params)
|
||||
|
||||
hop_length = np.prod(params.ratios)
|
||||
n_q = int(
|
||||
1000
|
||||
* params.target_bandwidths[-1]
|
||||
// (math.ceil(params.sampling_rate / hop_length) * 10)
|
||||
) if params.n_q is None else params.n_q
|
||||
n_q = (
|
||||
int(
|
||||
1000
|
||||
* params.target_bandwidths[-1]
|
||||
// (math.ceil(params.sampling_rate / hop_length) * 10)
|
||||
)
|
||||
if params.n_q is None
|
||||
else params.n_q
|
||||
)
|
||||
|
||||
encoder = SEANetEncoder(
|
||||
n_filters=params.generator_n_filters,
|
||||
@ -660,13 +664,17 @@ def train_one_epoch(
|
||||
# )
|
||||
tb_writer.add_image(
|
||||
"train/speech_hat_",
|
||||
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
|
||||
np.array(
|
||||
Image.open(plot_curve(speech_hat_i, params.sampling_rate))
|
||||
),
|
||||
params.batch_idx_train,
|
||||
dataformats="HWC",
|
||||
)
|
||||
tb_writer.add_image(
|
||||
"train/speech_",
|
||||
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
|
||||
np.array(
|
||||
Image.open(plot_curve(speech_i, params.sampling_rate))
|
||||
),
|
||||
params.batch_idx_train,
|
||||
dataformats="HWC",
|
||||
)
|
||||
@ -724,17 +732,20 @@ def train_one_epoch(
|
||||
# )
|
||||
tb_writer.add_image(
|
||||
f"train/valid_speech_hat_{index}",
|
||||
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
|
||||
np.array(
|
||||
Image.open(plot_curve(speech_hat_i, params.sampling_rate))
|
||||
),
|
||||
params.batch_idx_train,
|
||||
dataformats="HWC",
|
||||
)
|
||||
tb_writer.add_image(
|
||||
f"train/valid_speech_{index}",
|
||||
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
|
||||
np.array(
|
||||
Image.open(plot_curve(speech_i, params.sampling_rate))
|
||||
),
|
||||
params.batch_idx_train,
|
||||
dataformats="HWC",
|
||||
)
|
||||
|
||||
|
||||
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
|
||||
params.train_loss = loss_value
|
||||
|
@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
@ -118,14 +119,19 @@ def plot_feature(spectrogram):
|
||||
plt.close()
|
||||
return data
|
||||
|
||||
def plot_curve(speech: torch.Tensor, sampling_rate: int) -> bytes:
|
||||
import io
|
||||
|
||||
def plot_curve(
|
||||
speech: torch.Tensor, sampling_rate: int, figsize: List[int] = [10, 22]
|
||||
) -> bytes:
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
assert len(figsize) == 2, "figsize should be a list of two integers"
|
||||
|
||||
plt.figure()
|
||||
plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T)
|
||||
plt.rcParams["figure.figsize"] = figsize
|
||||
buf = io.BytesIO()
|
||||
plt.savefig(buf, format="jpeg")
|
||||
buf.seek(0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user