mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
Save rng states.
This commit is contained in:
parent
db2d9a6001
commit
97df1ce3eb
@ -16,6 +16,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
@ -159,6 +160,9 @@ class AsrModel(nn.Module):
|
|||||||
encoder_out_lens: torch.Tensor,
|
encoder_out_lens: torch.Tensor,
|
||||||
targets: torch.Tensor,
|
targets: torch.Tensor,
|
||||||
target_lengths: torch.Tensor,
|
target_lengths: torch.Tensor,
|
||||||
|
encoder_out_prev: Optional[torch.Tensor] = None,
|
||||||
|
encoder_out_lens_prev: Optional[torch.Tensor] = None,
|
||||||
|
model_prev=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Compute CTC loss.
|
"""Compute CTC loss.
|
||||||
Args:
|
Args:
|
||||||
@ -170,8 +174,43 @@ class AsrModel(nn.Module):
|
|||||||
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||||
to be un-padded and concatenated within 1 dimension.
|
to be un-padded and concatenated within 1 dimension.
|
||||||
"""
|
"""
|
||||||
|
device = encoder_out.device
|
||||||
|
if model_prev:
|
||||||
|
cpu_state = torch.get_rng_state()
|
||||||
|
cuda_state = torch.cuda.get_rng_state(device)
|
||||||
|
rng_state = random.getstate()
|
||||||
|
|
||||||
# Compute CTC log-prob
|
# Compute CTC log-prob
|
||||||
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||||
|
print(
|
||||||
|
"ctc_output",
|
||||||
|
ctc_output.detach().mean(),
|
||||||
|
ctc_output.detach().sum(),
|
||||||
|
ctc_output.detach().min(),
|
||||||
|
ctc_output.detach().max(),
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_prev:
|
||||||
|
with torch.random.fork_rng(devices=[device]):
|
||||||
|
torch.set_rng_state(cpu_state)
|
||||||
|
torch.cuda.set_rng_state(cuda_state, device)
|
||||||
|
|
||||||
|
rng_state2 = random.getstate()
|
||||||
|
random.setstate(rng_state)
|
||||||
|
|
||||||
|
ctc_output_prev = model_prev.ctc_output(encoder_out)
|
||||||
|
random.setstate(rng_state2)
|
||||||
|
print(
|
||||||
|
"ctc_output_prev",
|
||||||
|
ctc_output_prev.detach().mean(),
|
||||||
|
ctc_output_prev.detach().sum(),
|
||||||
|
ctc_output_prev.detach().min(),
|
||||||
|
ctc_output_prev.detach().max(),
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"isclose ctc",
|
||||||
|
(ctc_output - ctc_output).detach().abs().max(),
|
||||||
|
)
|
||||||
|
|
||||||
ctc_loss = torch.nn.functional.ctc_loss(
|
ctc_loss = torch.nn.functional.ctc_loss(
|
||||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||||
@ -345,6 +384,7 @@ class AsrModel(nn.Module):
|
|||||||
spec_augment: Optional[SpecAugment] = None,
|
spec_augment: Optional[SpecAugment] = None,
|
||||||
supervision_segments: Optional[torch.Tensor] = None,
|
supervision_segments: Optional[torch.Tensor] = None,
|
||||||
time_warp_factor: Optional[int] = 80,
|
time_warp_factor: Optional[int] = 80,
|
||||||
|
model_prev=None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -418,9 +458,53 @@ class AsrModel(nn.Module):
|
|||||||
x_lens = x_lens.repeat(2)
|
x_lens = x_lens.repeat(2)
|
||||||
y = k2.ragged.cat([y, y], axis=0)
|
y = k2.ragged.cat([y, y], axis=0)
|
||||||
|
|
||||||
|
device = x.device
|
||||||
|
if model_prev:
|
||||||
|
cpu_state = torch.get_rng_state()
|
||||||
|
cuda_state = torch.cuda.get_rng_state(device)
|
||||||
|
rng_state = random.getstate()
|
||||||
|
|
||||||
# Compute encoder outputs
|
# Compute encoder outputs
|
||||||
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
||||||
|
|
||||||
|
print(
|
||||||
|
"encoder_out",
|
||||||
|
encoder_out.detach().mean(),
|
||||||
|
encoder_out.detach().abs().max(),
|
||||||
|
encoder_out.detach().abs().min(),
|
||||||
|
encoder_out.detach().sum(),
|
||||||
|
encoder_out.shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_prev:
|
||||||
|
with torch.random.fork_rng(devices=[device]):
|
||||||
|
torch.set_rng_state(cpu_state)
|
||||||
|
torch.cuda.set_rng_state(cuda_state, device)
|
||||||
|
|
||||||
|
rng_state2 = random.getstate()
|
||||||
|
random.setstate(rng_state)
|
||||||
|
|
||||||
|
encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
|
||||||
|
x, x_lens
|
||||||
|
)
|
||||||
|
random.setstate(rng_state2)
|
||||||
|
print(
|
||||||
|
"encoder_out_prev",
|
||||||
|
encoder_out_prev.detach().mean(),
|
||||||
|
encoder_out_prev.detach().abs().max(),
|
||||||
|
encoder_out_prev.detach().abs().mean(),
|
||||||
|
encoder_out_prev.detach().sum(),
|
||||||
|
encoder_out_prev.shape,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"isclose",
|
||||||
|
(encoder_out - encoder_out_prev).detach().abs().max(),
|
||||||
|
(encoder_out_lens - encoder_out_lens_prev).detach().abs().max(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_out_prev = None
|
||||||
|
encoder_out_lens_prev = None
|
||||||
|
|
||||||
row_splits = y.shape.row_splits(1)
|
row_splits = y.shape.row_splits(1)
|
||||||
y_lens = row_splits[1:] - row_splits[:-1]
|
y_lens = row_splits[1:] - row_splits[:-1]
|
||||||
|
|
||||||
@ -451,6 +535,9 @@ class AsrModel(nn.Module):
|
|||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
target_lengths=y_lens,
|
target_lengths=y_lens,
|
||||||
|
encoder_out_prev=encoder_out_prev,
|
||||||
|
encoder_out_lens_prev=encoder_out_lens_prev,
|
||||||
|
model_prev=model_prev,
|
||||||
)
|
)
|
||||||
cr_loss = torch.empty(0)
|
cr_loss = torch.empty(0)
|
||||||
else:
|
else:
|
||||||
|
@ -549,6 +549,14 @@ def get_parser():
|
|||||||
help="Whether to use bf16 in AMP.",
|
help="Whether to use bf16 in AMP.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--limit-grad-start-batch",
|
||||||
|
type=int,
|
||||||
|
# default=1000,
|
||||||
|
default=2,
|
||||||
|
help="Limit grad starting from this batch.",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -879,6 +887,7 @@ def compute_loss(
|
|||||||
batch: dict,
|
batch: dict,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
spec_augment: Optional[SpecAugment] = None,
|
spec_augment: Optional[SpecAugment] = None,
|
||||||
|
model_prev: Union[nn.Module, DDP] = None,
|
||||||
) -> Tuple[Tensor, MetricsTracker]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
"""
|
"""
|
||||||
Compute loss given the model and its inputs.
|
Compute loss given the model and its inputs.
|
||||||
@ -942,6 +951,7 @@ def compute_loss(
|
|||||||
spec_augment=spec_augment,
|
spec_augment=spec_augment,
|
||||||
supervision_segments=supervision_segments,
|
supervision_segments=supervision_segments,
|
||||||
time_warp_factor=params.spec_aug_time_warp_factor,
|
time_warp_factor=params.spec_aug_time_warp_factor,
|
||||||
|
model_prev=model_prev,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
@ -1037,6 +1047,7 @@ def train_one_epoch(
|
|||||||
scaler: GradScaler,
|
scaler: GradScaler,
|
||||||
spec_augment: Optional[SpecAugment] = None,
|
spec_augment: Optional[SpecAugment] = None,
|
||||||
model_avg: Optional[nn.Module] = None,
|
model_avg: Optional[nn.Module] = None,
|
||||||
|
model_prev: Optional[nn.Module] = None,
|
||||||
tb_writer: Optional[SummaryWriter] = None,
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
rank: int = 0,
|
rank: int = 0,
|
||||||
@ -1104,9 +1115,14 @@ def train_one_epoch(
|
|||||||
with torch.cuda.amp.autocast(
|
with torch.cuda.amp.autocast(
|
||||||
enabled=params.use_autocast, dtype=params.dtype
|
enabled=params.use_autocast, dtype=params.dtype
|
||||||
):
|
):
|
||||||
|
if params.batch_idx_train > params.limit_grad_start_batch:
|
||||||
|
model_prev = copy.deepcopy(model)
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
|
model_prev=model_prev
|
||||||
|
if params.batch_idx_train > params.limit_grad_start_batch
|
||||||
|
else None,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
@ -1123,6 +1139,19 @@ def train_one_epoch(
|
|||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if params.batch_idx_train >= params.limit_grad_start_batch:
|
||||||
|
if model_prev is None:
|
||||||
|
model_prev = copy.deepcopy(model)
|
||||||
|
else:
|
||||||
|
model_prev = copy.deepcopy(model)
|
||||||
|
print(
|
||||||
|
"here",
|
||||||
|
params.batch_idx_train,
|
||||||
|
params.limit_grad_start_batch,
|
||||||
|
model_prev is None,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(f"Caught exception: {e}.")
|
logging.info(f"Caught exception: {e}.")
|
||||||
save_bad_model()
|
save_bad_model()
|
||||||
@ -1208,7 +1237,7 @@ def train_one_epoch(
|
|||||||
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
"train/grad_scale", cur_grad_scale, params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
|
||||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
if batch_idx % params.valid_interval == 1000 and not params.print_diagnostics:
|
||||||
logging.info("Computing validation loss")
|
logging.info("Computing validation loss")
|
||||||
valid_info = compute_validation_loss(
|
valid_info = compute_validation_loss(
|
||||||
params=params,
|
params=params,
|
||||||
@ -1233,6 +1262,8 @@ def train_one_epoch(
|
|||||||
params.best_train_epoch = params.cur_epoch
|
params.best_train_epoch = params.cur_epoch
|
||||||
params.best_train_loss = params.train_loss
|
params.best_train_loss = params.train_loss
|
||||||
|
|
||||||
|
return model_prev
|
||||||
|
|
||||||
|
|
||||||
def run(rank, world_size, args):
|
def run(rank, world_size, args):
|
||||||
"""
|
"""
|
||||||
@ -1319,6 +1350,9 @@ def run(rank, world_size, args):
|
|||||||
# model_avg is only used with rank 0
|
# model_avg is only used with rank 0
|
||||||
model_avg = copy.deepcopy(model).to(torch.float64)
|
model_avg = copy.deepcopy(model).to(torch.float64)
|
||||||
|
|
||||||
|
model_prev: Optional[nn.Module] = None
|
||||||
|
# TODO(fangjun): load checkpoint for model_prev
|
||||||
|
|
||||||
assert params.start_epoch > 0, params.start_epoch
|
assert params.start_epoch > 0, params.start_epoch
|
||||||
checkpoints = load_checkpoint_if_available(
|
checkpoints = load_checkpoint_if_available(
|
||||||
params=params, model=model, model_avg=model_avg
|
params=params, model=model, model_avg=model_avg
|
||||||
@ -1428,7 +1462,7 @@ def run(rank, world_size, args):
|
|||||||
valid_cuts += librispeech.dev_other_cuts()
|
valid_cuts += librispeech.dev_other_cuts()
|
||||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
if False and not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
model=model,
|
||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
@ -1453,10 +1487,11 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
params.cur_epoch = epoch
|
params.cur_epoch = epoch
|
||||||
|
|
||||||
train_one_epoch(
|
model_prev = train_one_epoch(
|
||||||
params=params,
|
params=params,
|
||||||
model=model,
|
model=model,
|
||||||
model_avg=model_avg,
|
model_avg=model_avg,
|
||||||
|
model_prev=model_prev,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
sp=sp,
|
sp=sp,
|
||||||
@ -1587,4 +1622,9 @@ torch.set_num_threads(1)
|
|||||||
torch.set_num_interop_threads(1)
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# torch.use_deterministic_algorithms(True, warn_only=True)
|
||||||
|
# torch.backends.cudnn.deterministic = True
|
||||||
|
# torch.backends.cudnn.benchmark = False
|
||||||
|
# torch.backends.cudnn.enabled = False
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user