mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-12 03:22:19 +00:00
Use contextmanager to manage rng state.
This commit is contained in:
parent
97df1ce3eb
commit
1e986c930d
@ -16,6 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@ -29,6 +30,21 @@ from scaling import ScaledLinear
|
||||
from icefall.utils import add_sos, make_pad_mask, time_warp
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fork_rng(cpu_state, cuda_state, rng_state, device):
|
||||
with torch.random.fork_rng(devices=[device]):
|
||||
torch.set_rng_state(cpu_state)
|
||||
torch.cuda.set_rng_state(cuda_state, device)
|
||||
|
||||
rng_state2 = random.getstate()
|
||||
random.setstate(rng_state)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
random.setstate(rng_state2)
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -191,15 +207,13 @@ class AsrModel(nn.Module):
|
||||
)
|
||||
|
||||
if model_prev:
|
||||
with torch.random.fork_rng(devices=[device]):
|
||||
torch.set_rng_state(cpu_state)
|
||||
torch.cuda.set_rng_state(cuda_state, device)
|
||||
|
||||
rng_state2 = random.getstate()
|
||||
random.setstate(rng_state)
|
||||
|
||||
with fork_rng(
|
||||
cpu_state=cpu_state,
|
||||
cuda_state=cuda_state,
|
||||
rng_state=rng_state,
|
||||
device=device,
|
||||
):
|
||||
ctc_output_prev = model_prev.ctc_output(encoder_out)
|
||||
random.setstate(rng_state2)
|
||||
print(
|
||||
"ctc_output_prev",
|
||||
ctc_output_prev.detach().mean(),
|
||||
@ -477,17 +491,15 @@ class AsrModel(nn.Module):
|
||||
)
|
||||
|
||||
if model_prev:
|
||||
with torch.random.fork_rng(devices=[device]):
|
||||
torch.set_rng_state(cpu_state)
|
||||
torch.cuda.set_rng_state(cuda_state, device)
|
||||
|
||||
rng_state2 = random.getstate()
|
||||
random.setstate(rng_state)
|
||||
|
||||
with fork_rng(
|
||||
cpu_state=cpu_state,
|
||||
cuda_state=cuda_state,
|
||||
rng_state=rng_state,
|
||||
device=device,
|
||||
):
|
||||
encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
|
||||
x, x_lens
|
||||
)
|
||||
random.setstate(rng_state2)
|
||||
print(
|
||||
"encoder_out_prev",
|
||||
encoder_out_prev.detach().mean(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user