Use contextmanager to manage rng state.

This commit is contained in:
Fangjun Kuang 2024-10-30 19:33:21 +08:00
parent 97df1ce3eb
commit 1e986c930d

View File

@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import random import random
from typing import Optional, Tuple from typing import Optional, Tuple
@ -29,6 +30,21 @@ from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask, time_warp from icefall.utils import add_sos, make_pad_mask, time_warp
@contextlib.contextmanager
def fork_rng(cpu_state, cuda_state, rng_state, device):
    """Temporarily run code under previously captured RNG states.

    Inside the ``with`` block the torch CPU generator, the CUDA generator of
    ``device`` (when ``device`` is a CUDA device) and Python's ``random``
    module are set to the given states. On exit, including when the body
    raises, the states that were active before entering are restored.

    Args:
      cpu_state:
        State to install with ``torch.set_rng_state``.
      cuda_state:
        State to install with ``torch.cuda.set_rng_state`` for ``device``.
        Ignored when ``device`` is not a CUDA device.
      rng_state:
        State to install with ``random.setstate``.
      device:
        The device whose generator is forked; anything accepted by
        ``torch.device``.
    """
    # Only a CUDA device has a CUDA generator to fork/restore. Skipping the
    # CUDA calls for other devices lets this helper also work on CPU.
    is_cuda = torch.device(device).type == "cuda"
    forked_devices = [device] if is_cuda else []

    # torch.random.fork_rng saves the torch states on entry and restores
    # them on exit, so only Python's `random` needs manual bookkeeping.
    with torch.random.fork_rng(devices=forked_devices):
        torch.set_rng_state(cpu_state)
        if is_cuda:
            torch.cuda.set_rng_state(cuda_state, device)

        rng_state2 = random.getstate()
        random.setstate(rng_state)
        try:
            yield
        finally:
            # Restore Python's RNG even if the body raised.
            random.setstate(rng_state2)
class AsrModel(nn.Module): class AsrModel(nn.Module):
def __init__( def __init__(
self, self,
@ -191,15 +207,13 @@ class AsrModel(nn.Module):
) )
if model_prev: if model_prev:
with torch.random.fork_rng(devices=[device]): with fork_rng(
torch.set_rng_state(cpu_state) cpu_state=cpu_state,
torch.cuda.set_rng_state(cuda_state, device) cuda_state=cuda_state,
rng_state=rng_state,
rng_state2 = random.getstate() device=device,
random.setstate(rng_state) ):
ctc_output_prev = model_prev.ctc_output(encoder_out) ctc_output_prev = model_prev.ctc_output(encoder_out)
random.setstate(rng_state2)
print( print(
"ctc_output_prev", "ctc_output_prev",
ctc_output_prev.detach().mean(), ctc_output_prev.detach().mean(),
@ -477,17 +491,15 @@ class AsrModel(nn.Module):
) )
if model_prev: if model_prev:
with torch.random.fork_rng(devices=[device]): with fork_rng(
torch.set_rng_state(cpu_state) cpu_state=cpu_state,
torch.cuda.set_rng_state(cuda_state, device) cuda_state=cuda_state,
rng_state=rng_state,
rng_state2 = random.getstate() device=device,
random.setstate(rng_state) ):
encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder( encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
x, x_lens x, x_lens
) )
random.setstate(rng_state2)
print( print(
"encoder_out_prev", "encoder_out_prev",
encoder_out_prev.detach().mean(), encoder_out_prev.detach().mean(),