Save rng states.

This commit is contained in:
Fangjun Kuang 2024-10-30 19:21:46 +08:00
parent db2d9a6001
commit 97df1ce3eb
2 changed files with 130 additions and 3 deletions

View File

@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from typing import Optional, Tuple
import k2
@ -159,6 +160,9 @@ class AsrModel(nn.Module):
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
encoder_out_prev: Optional[torch.Tensor] = None,
encoder_out_lens_prev: Optional[torch.Tensor] = None,
model_prev=None,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
@ -170,8 +174,43 @@ class AsrModel(nn.Module):
Target Tensor of shape (sum(target_lengths)). The targets are assumed
to be un-padded and concatenated within 1 dimension.
"""
device = encoder_out.device
if model_prev:
cpu_state = torch.get_rng_state()
cuda_state = torch.cuda.get_rng_state(device)
rng_state = random.getstate()
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
print(
"ctc_output",
ctc_output.detach().mean(),
ctc_output.detach().sum(),
ctc_output.detach().min(),
ctc_output.detach().max(),
)
if model_prev:
with torch.random.fork_rng(devices=[device]):
torch.set_rng_state(cpu_state)
torch.cuda.set_rng_state(cuda_state, device)
rng_state2 = random.getstate()
random.setstate(rng_state)
ctc_output_prev = model_prev.ctc_output(encoder_out)
random.setstate(rng_state2)
print(
"ctc_output_prev",
ctc_output_prev.detach().mean(),
ctc_output_prev.detach().sum(),
ctc_output_prev.detach().min(),
ctc_output_prev.detach().max(),
)
print(
"isclose ctc",
(ctc_output - ctc_output).detach().abs().max(),
)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
@ -345,6 +384,7 @@ class AsrModel(nn.Module):
spec_augment: Optional[SpecAugment] = None,
supervision_segments: Optional[torch.Tensor] = None,
time_warp_factor: Optional[int] = 80,
model_prev=None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
@ -418,9 +458,53 @@ class AsrModel(nn.Module):
x_lens = x_lens.repeat(2)
y = k2.ragged.cat([y, y], axis=0)
device = x.device
if model_prev:
cpu_state = torch.get_rng_state()
cuda_state = torch.cuda.get_rng_state(device)
rng_state = random.getstate()
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
print(
"encoder_out",
encoder_out.detach().mean(),
encoder_out.detach().abs().max(),
encoder_out.detach().abs().min(),
encoder_out.detach().sum(),
encoder_out.shape,
)
if model_prev:
with torch.random.fork_rng(devices=[device]):
torch.set_rng_state(cpu_state)
torch.cuda.set_rng_state(cuda_state, device)
rng_state2 = random.getstate()
random.setstate(rng_state)
encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
x, x_lens
)
random.setstate(rng_state2)
print(
"encoder_out_prev",
encoder_out_prev.detach().mean(),
encoder_out_prev.detach().abs().max(),
encoder_out_prev.detach().abs().mean(),
encoder_out_prev.detach().sum(),
encoder_out_prev.shape,
)
print(
"isclose",
(encoder_out - encoder_out_prev).detach().abs().max(),
(encoder_out_lens - encoder_out_lens_prev).detach().abs().max(),
)
else:
encoder_out_prev = None
encoder_out_lens_prev = None
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
@ -451,6 +535,9 @@ class AsrModel(nn.Module):
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
encoder_out_prev=encoder_out_prev,
encoder_out_lens_prev=encoder_out_lens_prev,
model_prev=model_prev,
)
cr_loss = torch.empty(0)
else:

View File

@ -549,6 +549,14 @@ def get_parser():
help="Whether to use bf16 in AMP.",
)
parser.add_argument(
"--limit-grad-start-batch",
type=int,
# default=1000,
default=2,
help="Limit grad starting from this batch.",
)
add_model_arguments(parser)
return parser
@ -879,6 +887,7 @@ def compute_loss(
batch: dict,
is_training: bool,
spec_augment: Optional[SpecAugment] = None,
model_prev: Union[nn.Module, DDP] = None,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute loss given the model and its inputs.
@ -942,6 +951,7 @@ def compute_loss(
spec_augment=spec_augment,
supervision_segments=supervision_segments,
time_warp_factor=params.spec_aug_time_warp_factor,
model_prev=model_prev,
)
loss = 0.0
@ -1037,6 +1047,7 @@ def train_one_epoch(
scaler: GradScaler,
spec_augment: Optional[SpecAugment] = None,
model_avg: Optional[nn.Module] = None,
model_prev: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@ -1104,9 +1115,14 @@ def train_one_epoch(
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
if params.batch_idx_train > params.limit_grad_start_batch:
model_prev = copy.deepcopy(model)
loss, loss_info = compute_loss(
params=params,
model=model,
model_prev=model_prev
if params.batch_idx_train > params.limit_grad_start_batch
else None,
sp=sp,
batch=batch,
is_training=True,
@ -1123,6 +1139,19 @@ def train_one_epoch(
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if params.batch_idx_train >= params.limit_grad_start_batch:
if model_prev is None:
model_prev = copy.deepcopy(model)
else:
model_prev = copy.deepcopy(model)
print(
"here",
params.batch_idx_train,
params.limit_grad_start_batch,
model_prev is None,
)
except Exception as e:
logging.info(f"Caught exception: {e}.")
save_bad_model()
@ -1208,7 +1237,7 @@ def train_one_epoch(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if batch_idx % params.valid_interval == 1000 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@ -1233,6 +1262,8 @@ def train_one_epoch(
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
return model_prev
def run(rank, world_size, args):
"""
@ -1319,6 +1350,9 @@ def run(rank, world_size, args):
# model_avg is only used with rank 0
model_avg = copy.deepcopy(model).to(torch.float64)
model_prev: Optional[nn.Module] = None
# TODO(fangjun): load checkpoint for model_prev
assert params.start_epoch > 0, params.start_epoch
checkpoints = load_checkpoint_if_available(
params=params, model=model, model_avg=model_avg
@ -1428,7 +1462,7 @@ def run(rank, world_size, args):
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if False and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
@ -1453,10 +1487,11 @@ def run(rank, world_size, args):
params.cur_epoch = epoch
train_one_epoch(
model_prev = train_one_epoch(
params=params,
model=model,
model_avg=model_avg,
model_prev=model_prev,
optimizer=optimizer,
scheduler=scheduler,
sp=sp,
@ -1587,4 +1622,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
# torch.use_deterministic_algorithms(True, warn_only=True)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.enabled = False
main()