mirror of https://github.com/k2-fsa/icefall.git
Save rng states.
commit 97df1ce3eb (parent db2d9a6001)
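The first group of hunks touches `AsrModel` (its `forward_ctc` and `forward` methods); the later hunks touch the accompanying training script (`get_parser`, `compute_loss`, `train_one_epoch`, `run`). The common pattern is: snapshot the CPU, CUDA, and Python `random` RNG states before the main forward pass, then replay them inside `torch.random.fork_rng` while re-running a frozen copy of the model (`model_prev`), so both passes draw identical randomness (dropout etc.) and their outputs can be compared. Below is a minimal, self-contained sketch of that pattern; the function and variable names are illustrative, not taken from the diff. Note that `fork_rng` covers only the torch CPU/CUDA generators, which is why the diff (and this sketch) saves and restores Python's `random` state by hand.

# --- illustrative example, not part of the diff ---
import random

import torch
import torch.nn as nn


def run_both_with_same_rng(model: nn.Module, model_prev: nn.Module, x: torch.Tensor):
    """Run `model` and `model_prev` on the same input with identical RNG draws.

    Both modules must consume random numbers in the same order (same dropout
    layers, same shapes); otherwise the streams diverge despite the restore.
    """
    device = x.device
    use_cuda = device.type == "cuda"

    # Snapshot every RNG stream the models may consume.
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state(device) if use_cuda else None
    py_state = random.getstate()

    y = model(x)  # the first pass advances the generators

    # Replay the snapshot for the second pass without disturbing global state.
    with torch.random.fork_rng(devices=[device] if use_cuda else []):
        torch.set_rng_state(cpu_state)
        if use_cuda:
            torch.cuda.set_rng_state(cuda_state, device)
        py_state_now = random.getstate()  # fork_rng does not cover Python's `random`
        random.setstate(py_state)
        y_prev = model_prev(x)
        random.setstate(py_state_now)     # put the Python RNG back by hand

    return y, y_prev


if __name__ == "__main__":
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
    net_prev = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
    net_prev.load_state_dict(net.state_dict())
    a, b = run_both_with_same_rng(net, net_prev, torch.randn(4, 8))
    print("max abs diff:", (a - b).abs().max().item())  # ~0: same weights, same dropout mask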
@@ -16,6 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import random
 from typing import Optional, Tuple
 
 import k2
@@ -159,6 +160,9 @@ class AsrModel(nn.Module):
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
+        encoder_out_prev: Optional[torch.Tensor] = None,
+        encoder_out_lens_prev: Optional[torch.Tensor] = None,
+        model_prev=None,
     ) -> torch.Tensor:
         """Compute CTC loss.
         Args:
@@ -170,8 +174,43 @@ class AsrModel(nn.Module):
             Target Tensor of shape (sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
         """
+        device = encoder_out.device
+        if model_prev:
+            cpu_state = torch.get_rng_state()
+            cuda_state = torch.cuda.get_rng_state(device)
+            rng_state = random.getstate()
+
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
+        print(
+            "ctc_output",
+            ctc_output.detach().mean(),
+            ctc_output.detach().sum(),
+            ctc_output.detach().min(),
+            ctc_output.detach().max(),
+        )
+
+        if model_prev:
+            with torch.random.fork_rng(devices=[device]):
+                torch.set_rng_state(cpu_state)
+                torch.cuda.set_rng_state(cuda_state, device)
+
+                rng_state2 = random.getstate()
+                random.setstate(rng_state)
+
+                ctc_output_prev = model_prev.ctc_output(encoder_out)
+                random.setstate(rng_state2)
+                print(
+                    "ctc_output_prev",
+                    ctc_output_prev.detach().mean(),
+                    ctc_output_prev.detach().sum(),
+                    ctc_output_prev.detach().min(),
+                    ctc_output_prev.detach().max(),
+                )
+                print(
+                    "isclose ctc",
+                    (ctc_output - ctc_output_prev).detach().abs().max(),
+                )
 
         ctc_loss = torch.nn.functional.ctc_loss(
             log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
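For reference, the `ctc_loss` call this hunk feeds into expects log-probabilities shaped (T, N, C) and a single concatenated target vector, as the docstring above notes. A small self-contained sketch of that call pattern (shapes and values are made up; the reduction argument is an assumption, since the hunk is truncated before it):

# --- illustrative example, not part of the diff ---
import torch
import torch.nn.functional as F

N, T, C = 2, 50, 30  # batch, frames, classes; index 0 is the CTC blank
log_probs = torch.randn(N, T, C).log_softmax(-1)  # like ctc_output: (N, T, C)
targets = torch.tensor([3, 5, 7, 2, 9])           # un-padded, concatenated over the batch
input_lengths = torch.tensor([50, 42])            # like encoder_out_lens
target_lengths = torch.tensor([3, 2])             # sums to targets.numel()

loss = F.ctc_loss(
    log_probs=log_probs.permute(1, 0, 2),  # ctc_loss wants (T, N, C)
    targets=targets,
    input_lengths=input_lengths,
    target_lengths=target_lengths,
    reduction="sum",  # assumption; the diff does not show this argument
)
print(loss)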
@@ -345,6 +384,7 @@ class AsrModel(nn.Module):
         spec_augment: Optional[SpecAugment] = None,
         supervision_segments: Optional[torch.Tensor] = None,
         time_warp_factor: Optional[int] = 80,
+        model_prev=None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -418,9 +458,53 @@ class AsrModel(nn.Module):
             x_lens = x_lens.repeat(2)
             y = k2.ragged.cat([y, y], axis=0)
 
+        device = x.device
+        if model_prev:
+            cpu_state = torch.get_rng_state()
+            cuda_state = torch.cuda.get_rng_state(device)
+            rng_state = random.getstate()
+
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
 
+        print(
+            "encoder_out",
+            encoder_out.detach().mean(),
+            encoder_out.detach().abs().max(),
+            encoder_out.detach().abs().min(),
+            encoder_out.detach().sum(),
+            encoder_out.shape,
+        )
+
+        if model_prev:
+            with torch.random.fork_rng(devices=[device]):
+                torch.set_rng_state(cpu_state)
+                torch.cuda.set_rng_state(cuda_state, device)
+
+                rng_state2 = random.getstate()
+                random.setstate(rng_state)
+
+                encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
+                    x, x_lens
+                )
+                random.setstate(rng_state2)
+                print(
+                    "encoder_out_prev",
+                    encoder_out_prev.detach().mean(),
+                    encoder_out_prev.detach().abs().max(),
+                    encoder_out_prev.detach().abs().mean(),
+                    encoder_out_prev.detach().sum(),
+                    encoder_out_prev.shape,
+                )
+                print(
+                    "isclose",
+                    (encoder_out - encoder_out_prev).detach().abs().max(),
+                    (encoder_out_lens - encoder_out_lens_prev).detach().abs().max(),
+                )
+        else:
+            encoder_out_prev = None
+            encoder_out_lens_prev = None
+
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
 
@@ -451,6 +535,9 @@ class AsrModel(nn.Module):
                 encoder_out_lens=encoder_out_lens,
                 targets=targets,
                 target_lengths=y_lens,
+                encoder_out_prev=encoder_out_prev,
+                encoder_out_lens_prev=encoder_out_lens_prev,
+                model_prev=model_prev,
             )
             cr_loss = torch.empty(0)
         else:
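The debug `print` calls above dump summary statistics of the fresh outputs and of the outputs recomputed with `model_prev`, plus the largest element-wise difference between the two. A hedged sketch of how such a comparison could be turned into a hard check (tolerances are illustrative, not from the diff):

# --- illustrative example, not part of the diff ---
import torch


def check_outputs_match(out: torch.Tensor, out_prev: torch.Tensor) -> None:
    # Same quantity the diff prints: the largest element-wise deviation.
    print("max abs diff:", (out - out_prev).detach().abs().max().item())
    # Raises with a readable report if the tensors differ beyond tolerance.
    torch.testing.assert_close(out, out_prev, rtol=1e-5, atol=1e-6)


a = torch.randn(3, 4)
check_outputs_match(a, a.clone())

The remaining hunks below modify the accompanying training script (`get_parser`, `compute_loss`, `train_one_epoch`, `run`).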
@@ -549,6 +549,14 @@ def get_parser():
         help="Whether to use bf16 in AMP.",
     )
 
+    parser.add_argument(
+        "--limit-grad-start-batch",
+        type=int,
+        # default=1000,
+        default=2,
+        help="Limit grad starting from this batch.",
+    )
+
     add_model_arguments(parser)
 
     return parser
@@ -879,6 +887,7 @@ def compute_loss(
     batch: dict,
     is_training: bool,
     spec_augment: Optional[SpecAugment] = None,
+    model_prev: Union[nn.Module, DDP] = None,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -942,6 +951,7 @@ def compute_loss(
             spec_augment=spec_augment,
             supervision_segments=supervision_segments,
             time_warp_factor=params.spec_aug_time_warp_factor,
+            model_prev=model_prev,
         )
 
     loss = 0.0
@@ -1037,6 +1047,7 @@ def train_one_epoch(
     scaler: GradScaler,
     spec_augment: Optional[SpecAugment] = None,
     model_avg: Optional[nn.Module] = None,
+    model_prev: Optional[nn.Module] = None,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -1104,9 +1115,14 @@ def train_one_epoch(
             with torch.cuda.amp.autocast(
                 enabled=params.use_autocast, dtype=params.dtype
             ):
+                if params.batch_idx_train > params.limit_grad_start_batch:
+                    model_prev = copy.deepcopy(model)
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
+                    model_prev=model_prev
+                    if params.batch_idx_train > params.limit_grad_start_batch
+                    else None,
                     sp=sp,
                     batch=batch,
                     is_training=True,
@@ -1123,6 +1139,19 @@ def train_one_epoch(
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
+
+            if params.batch_idx_train >= params.limit_grad_start_batch:
+                if model_prev is None:
+                    model_prev = copy.deepcopy(model)
+                else:
+                    model_prev = copy.deepcopy(model)
+                print(
+                    "here",
+                    params.batch_idx_train,
+                    params.limit_grad_start_batch,
+                    model_prev is None,
+                )
+
         except Exception as e:
             logging.info(f"Caught exception: {e}.")
             save_bad_model()
@@ -1208,7 +1237,7 @@
                     "train/grad_scale", cur_grad_scale, params.batch_idx_train
                 )
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        if batch_idx % params.valid_interval == 1000 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -1233,6 +1262,8 @@
         params.best_train_epoch = params.cur_epoch
         params.best_train_loss = params.train_loss
 
+    return model_prev
+
 
 def run(rank, world_size, args):
     """
@@ -1319,6 +1350,9 @@ def run(rank, world_size, args):
         # model_avg is only used with rank 0
         model_avg = copy.deepcopy(model).to(torch.float64)
 
+    model_prev: Optional[nn.Module] = None
+    # TODO(fangjun): load checkpoint for model_prev
+
     assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
         params=params, model=model, model_avg=model_avg
@@ -1428,7 +1462,7 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if False and not params.print_diagnostics:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -1453,10 +1487,11 @@ def run(rank, world_size, args):
 
         params.cur_epoch = epoch
 
-        train_one_epoch(
+        model_prev = train_one_epoch(
             params=params,
             model=model,
             model_avg=model_avg,
+            model_prev=model_prev,
             optimizer=optimizer,
             scheduler=scheduler,
             sp=sp,
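Putting the training-script pieces together: `--limit-grad-start-batch` gates when a frozen copy of the model starts being kept; past that point the copy is refreshed with `copy.deepcopy(model)` after each optimizer step, `compute_loss` receives it as `model_prev`, and `train_one_epoch` now returns it so it survives epoch boundaries. A runnable toy version of that flow (all names and the loss are placeholders, not the recipe's):

# --- illustrative toy, not part of the diff ---
import copy
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


def train_one_epoch(
    model: nn.Module,
    model_prev: Optional[nn.Module],
    optimizer: torch.optim.Optimizer,
    data: List[Tuple[torch.Tensor, torch.Tensor]],
    limit_grad_start_batch: int,
    batch_idx_train: int,
) -> Tuple[Optional[nn.Module], int]:
    for x, y in data:
        batch_idx_train += 1
        pred = model(x)
        if model_prev is not None:
            with torch.no_grad():
                pred_prev = model_prev(x)  # reference output from the previous step's weights
            print("max |pred - pred_prev|:", (pred - pred_prev).abs().max().item())
        loss = nn.functional.mse_loss(pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch_idx_train >= limit_grad_start_batch:
            # Snapshot the just-updated weights for the next batch's comparison.
            model_prev = copy.deepcopy(model)
    # Returned so the snapshot carries over to the next epoch, as in the diff.
    return model_prev, batch_idx_train


model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]

model_prev: Optional[nn.Module] = None  # TODO in the diff: optionally load it from a checkpoint
batch_idx_train = 0
for epoch in range(2):
    model_prev, batch_idx_train = train_one_epoch(
        model, model_prev, optimizer, data,
        limit_grad_start_batch=2, batch_idx_train=batch_idx_train,
    )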
@@ -1587,4 +1622,9 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 if __name__ == "__main__":
+    # torch.use_deterministic_algorithms(True, warn_only=True)
+    # torch.backends.cudnn.deterministic = True
+    # torch.backends.cudnn.benchmark = False
+    # torch.backends.cudnn.enabled = False
+
     main()
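The commented-out lines above hint at the other ingredient for making the two passes bit-for-bit comparable: deterministic kernels. A short sketch of what enabling them would look like (the commit itself leaves them disabled):

# --- illustrative example, not part of the diff ---
import torch

# Prefer deterministic implementations; warn (rather than error) when an op has none.
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True  # force deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False     # disable autotuning, which may pick different kernels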