diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py
index b8f7d6336..ab216a2ae 100644
--- a/egs/librispeech/ASR/zipformer/model.py
+++ b/egs/librispeech/ASR/zipformer/model.py
@@ -16,6 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import random
 from typing import Optional, Tuple
 
@@ -29,6 +30,21 @@ from scaling import ScaledLinear
 from icefall.utils import add_sos, make_pad_mask, time_warp
 
 
+@contextlib.contextmanager
+def fork_rng(cpu_state, cuda_state, rng_state, device):
+    with torch.random.fork_rng(devices=[device]):
+        torch.set_rng_state(cpu_state)
+        torch.cuda.set_rng_state(cuda_state, device)
+
+        rng_state2 = random.getstate()
+        random.setstate(rng_state)
+
+        try:
+            yield
+        finally:
+            random.setstate(rng_state2)
+
+
 class AsrModel(nn.Module):
     def __init__(
         self,
@@ -191,15 +207,13 @@ class AsrModel(nn.Module):
         )
 
         if model_prev:
-            with torch.random.fork_rng(devices=[device]):
-                torch.set_rng_state(cpu_state)
-                torch.cuda.set_rng_state(cuda_state, device)
-
-                rng_state2 = random.getstate()
-                random.setstate(rng_state)
-
+            with fork_rng(
+                cpu_state=cpu_state,
+                cuda_state=cuda_state,
+                rng_state=rng_state,
+                device=device,
+            ):
                 ctc_output_prev = model_prev.ctc_output(encoder_out)
-                random.setstate(rng_state2)
             print(
                 "ctc_output_prev",
                 ctc_output_prev.detach().mean(),
@@ -477,17 +491,15 @@ class AsrModel(nn.Module):
         )
 
         if model_prev:
-            with torch.random.fork_rng(devices=[device]):
-                torch.set_rng_state(cpu_state)
-                torch.cuda.set_rng_state(cuda_state, device)
-
-                rng_state2 = random.getstate()
-                random.setstate(rng_state)
-
+            with fork_rng(
+                cpu_state=cpu_state,
+                cuda_state=cuda_state,
+                rng_state=rng_state,
+                device=device,
+            ):
                 encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
                     x, x_lens
                 )
-                random.setstate(rng_state2)
             print(
                 "encoder_out_prev",
                 encoder_out_prev.detach().mean(),