Merge 752e16be1038211bada5d4f15eb4b59d3f6ae9f6 into c401a2646b347bf1fff0c2ce1a4ee13b0f482448

Erwan Zerhouni 2024-01-26 16:24:46 +08:00 committed by GitHub
commit 6c98fbc309
4 changed files with 594 additions and 196 deletions


@@ -0,0 +1,177 @@
#!/usr/bin/env python3
#
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang,
# Zengwei Yao,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from icefall.utils import make_pad_mask

NON_BLANK_THRES = 0.9


class FrameReducer(nn.Module):
    """The encoder output is first used to calculate the CTC posterior
    probability; then, for each output frame, if its blank posterior is
    greater than the threshold NON_BLANK_THRES, the frame is simply
    discarded from the encoder output.
    """
    def __init__(
        self,
    ):
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        ctc_output: torch.Tensor,
        y_lens: Optional[torch.Tensor] = None,
        blank_id: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The shared encoder output with shape [N, T, C].
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          ctc_output:
            The CTC output with shape [N, T, vocab_size].
          y_lens:
            A tensor of shape (batch_size,) containing the number of tokens in
            `y` before padding.
          blank_id:
            The blank id of ctc_output.
        Returns:
          out:
            The frame-reduced encoder output with shape [N, T', C].
          out_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `out` before padding.
        """
        N, T, C = x.size()

        padding_mask = make_pad_mask(x_lens)
        # Keep a frame only if it is not padding and its blank posterior is
        # below NON_BLANK_THRES.
        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) * (
            ~padding_mask
        )

        if y_lens is not None or self.training is False:
            # Limit the maximum number of reduced frames
            if y_lens is not None:
                # At most T - y_lens frames may be dropped, so that at least
                # y_lens frames survive for each utterance.
                limit_lens = T - y_lens
            else:
                # In eval mode, make sure completely silent audio does not cause errors
                limit_lens = T - torch.ones_like(x_lens)
            max_limit_len = limit_lens.max().int()
            # Positions of the most blank-like frames for each utterance.
            fake_limit_indexes = torch.topk(
                ctc_output[:, :, blank_id], max_limit_len
            ).indices
            _T = (
                torch.arange(max_limit_len)
                .expand_as(
                    fake_limit_indexes,
                )
                .to(device=x.device)
            )
            _T = torch.remainder(_T, limit_lens.unsqueeze(1))
            limit_indexes = torch.gather(fake_limit_indexes, 1, _T)
            limit_mask = (
                torch.full_like(
                    non_blank_mask,
                    0,
                    device=x.device,
                ).scatter_(1, limit_indexes, 1)
                == 1
            )
            # Force-keep every frame that is not among the limit_lens most
            # blank-like frames of its utterance.
            non_blank_mask = non_blank_mask | ~limit_mask

        out_lens = non_blank_mask.sum(dim=1)
        max_len = out_lens.max()
        # Pad x on the right so that, after masking, every utterance keeps
        # exactly max_len frames: out_lens[i] real frames plus
        # (max_len - out_lens[i]) appended pad frames.
        pad_lens_list = (
            torch.full_like(
                out_lens,
                max_len.item(),
                device=x.device,
            )
            - out_lens
        )
        max_pad_len = int(pad_lens_list.max().item())

        out = F.pad(x, (0, 0, 0, max_pad_len))

        valid_pad_mask = ~make_pad_mask(pad_lens_list)
        total_valid_mask = torch.concat([non_blank_mask, valid_pad_mask], dim=1)

        out = out[total_valid_mask].reshape(N, -1, C)

        return out, out_lens
if __name__ == "__main__":
    import time

    test_times = 10000
    device = "cuda:0"
    frame_reducer = FrameReducer()

    # non zero case
    x = torch.ones(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    ctc_output = torch.log(
        torch.randn(15, 498, 500, dtype=torch.float32, device=device),
    )

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
    print(x_fr.shape)
    print(x_lens_fr)
    print(avg_time / test_times)

    # all zero case
    x = torch.zeros(15, 498, 384, dtype=torch.float32, device=device)
    x_lens = torch.tensor([498] * 15, dtype=torch.int64, device=device)
    y_lens = torch.tensor([150] * 15, dtype=torch.int64, device=device)
    ctc_output = torch.zeros(15, 498, 500, dtype=torch.float32, device=device)

    avg_time = 0
    for i in range(test_times):
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time()
        x_fr, x_lens_fr = frame_reducer(x, x_lens, ctc_output, y_lens)
        torch.cuda.synchronize(device=x.device)
        delta_time = time.time() - delta_time
        avg_time += delta_time
    print(x_fr.shape)
    print(x_lens_fr)
    print(avg_time / test_times)
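For reference, the thresholding rule above can be checked on CPU with a few lines. This is a minimal sketch, not part of the commit, and the toy `log_probs` values are made up for illustration: a frame survives only if its blank log-posterior is below log(NON_BLANK_THRES), i.e. its blank probability is below 0.9.

    import torch

    NON_BLANK_THRES = 0.9
    log_probs = torch.tensor(
        [
            [0.95, 0.05],  # frame 0: blank prob 0.95 -> dropped
            [0.40, 0.60],  # frame 1: blank prob 0.40 -> kept
            [0.92, 0.08],  # frame 2: blank prob 0.92 -> dropped
            [0.10, 0.90],  # frame 3: blank prob 0.10 -> kept
        ]
    ).log()
    keep = log_probs[:, 0] < torch.log(torch.tensor(NON_BLANK_THRES))  # blank_id = 0
    print(keep)  # tensor([False,  True, False,  True])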


@@ -0,0 +1,113 @@
# Copyright 2022 Xiaomi Corp. (authors: Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from acoustic_model.utils_py.scaling_zipformer import Balancer, ScaledConv1d
class LConv(nn.Module):
    """A convolution module to prevent information loss."""

    def __init__(
        self,
        channels: int,
        kernel_size: int = 7,
        bias: bool = True,
    ):
        """
        Args:
          channels:
            Dimension of the input embedding, and of the lconv output.
        """
        super().__init__()
        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )

        self.deriv_balancer1 = Balancer(
            2 * channels,
            channel_dim=1,
            min_abs=0.05,
            max_abs=10.0,
            min_positive=0.05,
            max_positive=1.0,
        )

        self.depthwise_conv = nn.Conv1d(
            2 * channels,
            2 * channels,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=2 * channels,
            bias=bias,
        )

        self.deriv_balancer2 = Balancer(
            2 * channels,
            channel_dim=1,
            min_positive=0.05,
            max_positive=1.0,
            min_abs=0.05,
            max_abs=20.0,
        )

        self.pointwise_conv2 = ScaledConv1d(
            2 * channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            initial_scale=0.05,
        )
    def forward(
        self,
        x: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
          x: A 3-D tensor of shape (N, T, C).
        Returns:
          Return a tensor of shape (N, T, C).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.permute(0, 2, 1)  # (batch, channels, time)

        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
        x = self.deriv_balancer1(x)

        if src_key_padding_mask is not None:
            x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)

        x = self.depthwise_conv(x)
        x = self.deriv_balancer2(x)
        x = self.pointwise_conv2(x)  # (batch, channels, time)

        return x.permute(0, 2, 1)
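A minimal usage sketch for LConv (not part of the commit), assuming the Balancer/ScaledConv1d import above resolves in this repo. It shows that the module keeps the (N, T, C) shape and how the padding mask is passed, mirroring the call in forward_ctc below:

    import torch
    from icefall.utils import make_pad_mask

    lconv = LConv(channels=384)
    x = torch.randn(2, 100, 384)      # (N, T, C)
    x_lens = torch.tensor([100, 80])  # number of valid frames per utterance
    out = lconv(x, src_key_padding_mask=make_pad_mask(x_lens))
    print(out.shape)  # torch.Size([2, 100, 384]) -- shape is preserved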


@@ -16,16 +16,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Tuple
+import warnings
+from typing import List, Optional, Tuple

 import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
-from icefall.utils import add_sos, make_pad_mask
 from scaling import ScaledLinear
+from icefall.utils import add_sos, encode_supervisions, make_pad_mask


 class AsrModel(nn.Module):
     def __init__(
@@ -34,11 +35,14 @@ class AsrModel(nn.Module):
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
         joiner: Optional[nn.Module] = None,
+        lconv: Optional[nn.Module] = None,
+        frame_reducer: Optional[nn.Module] = None,
         encoder_dim: int = 384,
         decoder_dim: int = 512,
         vocab_size: int = 500,
         use_transducer: bool = True,
         use_ctc: bool = False,
+        use_bs: bool = True,
     ):
         """A joint CTC & Transducer ASR model.
@@ -77,6 +81,10 @@
             use_transducer or use_ctc
         ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"

+        assert (
+            use_ctc or not use_bs
+        ), "Blank Skip needs CTC"
+
         assert isinstance(encoder, EncoderInterface), type(encoder)

         self.encoder_embed = encoder_embed
@@ -111,6 +119,11 @@
                 nn.LogSoftmax(dim=-1),
             )

+        self.use_bs = use_bs
+        if self.use_bs:
+            self.lconv = lconv
+            self.frame_reducer = frame_reducer
+
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -146,8 +159,13 @@
         self,
         encoder_out: torch.Tensor,
         encoder_out_lens: torch.Tensor,
-        targets: torch.Tensor,
+        targets: List[int],
         target_lengths: torch.Tensor,
+        supervisions: dict,
+        subsampling_factor: int,
+        ctc_beam_size: int,
+        reduction: str = "sum",
+        warmup: float = 1.0,
     ) -> torch.Tensor:
         """Compute CTC loss.
         Args:
@@ -158,18 +176,60 @@
           targets:
             Target Tensor of shape (sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
+          supervisions:
+            Dict into a pair of torch Tensor, and a list of transcription strings or token indexes
+          reduction:
+            Specifies the reduction to apply to the output
         """
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
+        encoder_out_fr = encoder_out
+        encoder_out_lens_fr = encoder_out_lens

-        ctc_loss = torch.nn.functional.ctc_loss(
-            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
-            targets=targets,
-            input_lengths=encoder_out_lens,
-            target_lengths=target_lengths,
-            reduction="sum",
-        )
-        return ctc_loss
+        if self.use_bs and warmup >= 2.0:
+            # lconv
+            encoder_out = self.lconv(
+                x=encoder_out,
+                src_key_padding_mask=make_pad_mask(encoder_out_lens),
+            )
+
+            # frame reduce
+            encoder_out_fr, encoder_out_lens_fr = self.frame_reducer(
+                encoder_out,
+                encoder_out_lens,
+                ctc_output,
+                target_lengths,
+                self.decoder.blank_id,
+            )
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            supervision_segments, token_ids = encode_supervisions(
+                supervisions,
+                subsampling_factor=subsampling_factor,
+                token_ids=targets,
+            )
+
+        # TODO: Crash without this line
+        supervision_segments = supervision_segments.to("cpu")
+
+        decoding_graph = k2.ctc_graph(
+            token_ids, modified=False, device=encoder_out.device
+        )
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            ctc_output,
+            supervision_segments,
+            allow_truncate=subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=ctc_beam_size,
+            reduction=reduction,
+            use_double_scores=True,
+        )
+
+        return ctc_loss, encoder_out_fr, encoder_out_lens_fr

     def forward_transducer(
         self,
@@ -180,6 +240,8 @@
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        delay_penalty: float = 0.0,
+        reduction: str = "sum",
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -199,6 +261,8 @@
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          reduction:
+            Specifies the reduction to apply to the output
         """
         # Now for the decoder, i.e., the prediction network
         blank_id = self.decoder.blank_id
@@ -226,11 +290,6 @@
         lm = self.simple_lm_proj(decoder_out)
         am = self.simple_am_proj(encoder_out)

-        # if self.training and random.random() < 0.25:
-        #     lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
-        # if self.training and random.random() < 0.25:
-        #     am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
-
         with torch.cuda.amp.autocast(enabled=False):
             simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                 lm=lm.float(),
@@ -240,7 +299,8 @@
                 lm_only_scale=lm_scale,
                 am_only_scale=am_scale,
                 boundary=boundary,
-                reduction="sum",
+                delay_penalty=delay_penalty,
+                reduction=reduction,
                 return_grad=True,
             )
@@ -273,7 +333,8 @@
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
-                reduction="sum",
+                delay_penalty=delay_penalty,
+                reduction=reduction,
             )

         return simple_loss, pruned_loss
@@ -283,9 +344,15 @@
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        supervisions: dict,
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        subsampling_factor: int = 4,
+        ctc_beam_size: int = 10,
+        delay_penalty: float = 0.0,
+        reduction: str = "sum",
+        warmup: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -306,10 +373,11 @@
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          reduction:
+            Specifies the reduction to apply to the output
         Returns:
           Return the transducer losses and CTC loss,
           in form of (simple_loss, pruned_loss, ctc_loss)
         Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
@@ -320,7 +388,7 @@
         assert x_lens.ndim == 1, x_lens.shape
         assert y.num_axes == 2, y.num_axes
-        assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
+        assert x.size(0) == x_lens.size(0) == y.dim0

         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
@@ -328,6 +396,22 @@
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]

+        if self.use_ctc:
+            # Compute CTC loss
+            ctc_loss, encoder_out, encoder_out_lens = self.forward_ctc(
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                targets=y.tolist(),
+                target_lengths=y_lens,
+                supervisions=supervisions,
+                subsampling_factor=subsampling_factor,
+                ctc_beam_size=ctc_beam_size,
+                reduction=reduction,
+                warmup=warmup,
+            )
+        else:
+            ctc_loss = torch.empty(0, device=encoder_out.device)
+
         if self.use_transducer:
             # Compute transducer loss
             simple_loss, pruned_loss = self.forward_transducer(
@@ -338,21 +422,11 @@
                 prune_range=prune_range,
                 am_scale=am_scale,
                 lm_scale=lm_scale,
+                reduction=reduction,
+                delay_penalty=delay_penalty,
             )
         else:
             simple_loss = torch.empty(0)
             pruned_loss = torch.empty(0)

-        if self.use_ctc:
-            # Compute CTC loss
-            targets = y.values
-            ctc_loss = self.forward_ctc(
-                encoder_out=encoder_out,
-                encoder_out_lens=encoder_out_lens,
-                targets=targets,
-                target_lengths=target_lengths,
-            )
-        else:
-            ctc_loss = torch.empty(0)
-
         return simple_loss, pruned_loss, ctc_loss


@@ -67,7 +67,9 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
+from frame_reducer import FrameReducer
 from joiner import Joiner
+from lconv import LConv
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
@@ -258,6 +260,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use CTC head.",
     )

+    parser.add_argument(
+        "--use-bs",
+        type=str2bool,
+        default=False,
+        help="If True, use blank-skip.",
+    )
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -529,6 +538,7 @@ def get_params() -> AttributeDict:
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
+            "ctc_beam_size": 10,
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 2000,
@@ -583,6 +593,16 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     return encoder


+def get_lconv(params: AttributeDict) -> nn.Module:
+    lconv = LConv(channels=max(params.encoder_dim))
+    return lconv
+
+
+def get_frame_reducer(params: AttributeDict) -> nn.Module:
+    frame_reducer = FrameReducer()
+    return frame_reducer
+
+
 def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
@@ -620,16 +640,24 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None

+    lconv, frame_reducer = None, None
+    if params.use_bs:
+        lconv = get_lconv(params)
+        frame_reducer = get_frame_reducer(params)
+
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        lconv=lconv,
+        frame_reducer=frame_reducer,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        use_bs=params.use_bs,
     )
     return model
@@ -756,6 +784,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    warmup: float,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -796,9 +825,12 @@
             x=feature,
             x_lens=feature_lens,
             y=y,
+            supervisions=supervisions,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            ctc_beam_size=params.ctc_beam_size,
+            warmup=warmup,
         )

         loss = 0.0
@@ -859,6 +891,7 @@ def compute_validation_loss(
             sp=sp,
             batch=batch,
             is_training=False,
+            warmup=(params.batch_idx_train / params.warm_step),
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -953,6 +986,7 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
+                warmup=(params.batch_idx_train / params.warm_step),
             )

         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
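A worked note on the warmup gating introduced above (a sketch, not part of the commit): compute_loss now receives warmup = params.batch_idx_train / params.warm_step, and forward_ctc only enters the lconv/frame-reducer branch when self.use_bs is set and warmup >= 2.0. With the default warm_step of 2000 shown in get_params(), blank-skip therefore becomes active after roughly 4000 training batches:

    warm_step = 2000          # "warm_step" in get_params()
    batch_idx_train = 4000    # hypothetical training step
    warmup = batch_idx_train / warm_step
    print(warmup >= 2.0)      # True -> blank-skip branch is used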