mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
minor updates
* black formatted * customize fig size
This commit is contained in:
parent
fe27d2ca36
commit
aa9132c82c
@ -329,11 +329,15 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
params.update(inference_params)
|
params.update(inference_params)
|
||||||
|
|
||||||
hop_length = np.prod(params.ratios)
|
hop_length = np.prod(params.ratios)
|
||||||
n_q = int(
|
n_q = (
|
||||||
1000
|
int(
|
||||||
* params.target_bandwidths[-1]
|
1000
|
||||||
// (math.ceil(params.sampling_rate / hop_length) * 10)
|
* params.target_bandwidths[-1]
|
||||||
) if params.n_q is None else params.n_q
|
// (math.ceil(params.sampling_rate / hop_length) * 10)
|
||||||
|
)
|
||||||
|
if params.n_q is None
|
||||||
|
else params.n_q
|
||||||
|
)
|
||||||
|
|
||||||
encoder = SEANetEncoder(
|
encoder = SEANetEncoder(
|
||||||
n_filters=params.generator_n_filters,
|
n_filters=params.generator_n_filters,
|
||||||
@ -660,13 +664,17 @@ def train_one_epoch(
|
|||||||
# )
|
# )
|
||||||
tb_writer.add_image(
|
tb_writer.add_image(
|
||||||
"train/speech_hat_",
|
"train/speech_hat_",
|
||||||
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
|
np.array(
|
||||||
|
Image.open(plot_curve(speech_hat_i, params.sampling_rate))
|
||||||
|
),
|
||||||
params.batch_idx_train,
|
params.batch_idx_train,
|
||||||
dataformats="HWC",
|
dataformats="HWC",
|
||||||
)
|
)
|
||||||
tb_writer.add_image(
|
tb_writer.add_image(
|
||||||
"train/speech_",
|
"train/speech_",
|
||||||
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
|
np.array(
|
||||||
|
Image.open(plot_curve(speech_i, params.sampling_rate))
|
||||||
|
),
|
||||||
params.batch_idx_train,
|
params.batch_idx_train,
|
||||||
dataformats="HWC",
|
dataformats="HWC",
|
||||||
)
|
)
|
||||||
@ -724,17 +732,20 @@ def train_one_epoch(
|
|||||||
# )
|
# )
|
||||||
tb_writer.add_image(
|
tb_writer.add_image(
|
||||||
f"train/valid_speech_hat_{index}",
|
f"train/valid_speech_hat_{index}",
|
||||||
np.array(Image.open(plot_curve(speech_hat_i, params.sampling_rate))),
|
np.array(
|
||||||
|
Image.open(plot_curve(speech_hat_i, params.sampling_rate))
|
||||||
|
),
|
||||||
params.batch_idx_train,
|
params.batch_idx_train,
|
||||||
dataformats="HWC",
|
dataformats="HWC",
|
||||||
)
|
)
|
||||||
tb_writer.add_image(
|
tb_writer.add_image(
|
||||||
f"train/valid_speech_{index}",
|
f"train/valid_speech_{index}",
|
||||||
np.array(Image.open(plot_curve(speech_i, params.sampling_rate))),
|
np.array(
|
||||||
|
Image.open(plot_curve(speech_i, params.sampling_rate))
|
||||||
|
),
|
||||||
params.batch_idx_train,
|
params.batch_idx_train,
|
||||||
dataformats="HWC",
|
dataformats="HWC",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
|
loss_value = tot_loss["generator_loss"] / tot_loss["samples"]
|
||||||
params.train_loss = loss_value
|
params.train_loss = loss_value
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections
|
import collections
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
@ -118,14 +119,19 @@ def plot_feature(spectrogram):
|
|||||||
plt.close()
|
plt.close()
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def plot_curve(speech: torch.Tensor, sampling_rate: int) -> bytes:
|
|
||||||
import io
|
def plot_curve(
|
||||||
|
speech: torch.Tensor, sampling_rate: int, figsize: List[int] = [10, 22]
|
||||||
|
) -> bytes:
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
assert len(figsize) == 2, "figsize should be a list of two integers"
|
||||||
|
|
||||||
plt.figure()
|
plt.figure()
|
||||||
plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T)
|
plt.plot(np.arange(sampling_rate) / sampling_rate, speech.detach().cpu().numpy().T)
|
||||||
|
plt.rcParams["figure.figsize"] = figsize
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
plt.savefig(buf, format="jpeg")
|
plt.savefig(buf, format="jpeg")
|
||||||
buf.seek(0)
|
buf.seek(0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user