mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 03:22:19 +00:00
Use contextmanager to manage rng state.
This commit is contained in:
parent
97df1ce3eb
commit
1e986c930d
@ -16,6 +16,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
@ -29,6 +30,21 @@ from scaling import ScaledLinear
|
|||||||
from icefall.utils import add_sos, make_pad_mask, time_warp
|
from icefall.utils import add_sos, make_pad_mask, time_warp
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def fork_rng(cpu_state, cuda_state, rng_state, device):
|
||||||
|
with torch.random.fork_rng(devices=[device]):
|
||||||
|
torch.set_rng_state(cpu_state)
|
||||||
|
torch.cuda.set_rng_state(cuda_state, device)
|
||||||
|
|
||||||
|
rng_state2 = random.getstate()
|
||||||
|
random.setstate(rng_state)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
random.setstate(rng_state2)
|
||||||
|
|
||||||
|
|
||||||
class AsrModel(nn.Module):
|
class AsrModel(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -191,15 +207,13 @@ class AsrModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_prev:
|
if model_prev:
|
||||||
with torch.random.fork_rng(devices=[device]):
|
with fork_rng(
|
||||||
torch.set_rng_state(cpu_state)
|
cpu_state=cpu_state,
|
||||||
torch.cuda.set_rng_state(cuda_state, device)
|
cuda_state=cuda_state,
|
||||||
|
rng_state=rng_state,
|
||||||
rng_state2 = random.getstate()
|
device=device,
|
||||||
random.setstate(rng_state)
|
):
|
||||||
|
|
||||||
ctc_output_prev = model_prev.ctc_output(encoder_out)
|
ctc_output_prev = model_prev.ctc_output(encoder_out)
|
||||||
random.setstate(rng_state2)
|
|
||||||
print(
|
print(
|
||||||
"ctc_output_prev",
|
"ctc_output_prev",
|
||||||
ctc_output_prev.detach().mean(),
|
ctc_output_prev.detach().mean(),
|
||||||
@ -477,17 +491,15 @@ class AsrModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_prev:
|
if model_prev:
|
||||||
with torch.random.fork_rng(devices=[device]):
|
with fork_rng(
|
||||||
torch.set_rng_state(cpu_state)
|
cpu_state=cpu_state,
|
||||||
torch.cuda.set_rng_state(cuda_state, device)
|
cuda_state=cuda_state,
|
||||||
|
rng_state=rng_state,
|
||||||
rng_state2 = random.getstate()
|
device=device,
|
||||||
random.setstate(rng_state)
|
):
|
||||||
|
|
||||||
encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
|
encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
|
||||||
x, x_lens
|
x, x_lens
|
||||||
)
|
)
|
||||||
random.setstate(rng_state2)
|
|
||||||
print(
|
print(
|
||||||
"encoder_out_prev",
|
"encoder_out_prev",
|
||||||
encoder_out_prev.detach().mean(),
|
encoder_out_prev.detach().mean(),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user