Use contextmanager to manage rng state.

This commit is contained in:
Fangjun Kuang 2024-10-30 19:33:21 +08:00
parent 97df1ce3eb
commit 1e986c930d

View File

@ -16,6 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import random
from typing import Optional, Tuple
@ -29,6 +30,21 @@ from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask, time_warp
@contextlib.contextmanager
def fork_rng(cpu_state, cuda_state, rng_state, device):
with torch.random.fork_rng(devices=[device]):
torch.set_rng_state(cpu_state)
torch.cuda.set_rng_state(cuda_state, device)
rng_state2 = random.getstate()
random.setstate(rng_state)
try:
yield
finally:
random.setstate(rng_state2)
class AsrModel(nn.Module):
def __init__(
self,
@ -191,15 +207,13 @@ class AsrModel(nn.Module):
)
if model_prev:
with torch.random.fork_rng(devices=[device]):
torch.set_rng_state(cpu_state)
torch.cuda.set_rng_state(cuda_state, device)
rng_state2 = random.getstate()
random.setstate(rng_state)
with fork_rng(
cpu_state=cpu_state,
cuda_state=cuda_state,
rng_state=rng_state,
device=device,
):
ctc_output_prev = model_prev.ctc_output(encoder_out)
random.setstate(rng_state2)
print(
"ctc_output_prev",
ctc_output_prev.detach().mean(),
@ -477,17 +491,15 @@ class AsrModel(nn.Module):
)
if model_prev:
with torch.random.fork_rng(devices=[device]):
torch.set_rng_state(cpu_state)
torch.cuda.set_rng_state(cuda_state, device)
rng_state2 = random.getstate()
random.setstate(rng_state)
with fork_rng(
cpu_state=cpu_state,
cuda_state=cuda_state,
rng_state=rng_state,
device=device,
):
encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
x, x_lens
)
random.setstate(rng_state2)
print(
"encoder_out_prev",
encoder_out_prev.detach().mean(),