mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Fix devices
This commit is contained in:
parent
a266922678
commit
df39fc6783
@ -259,7 +259,7 @@ def compute_beta_prods(Psum, Ptop):
|
||||
# Shape is (*, K)
|
||||
S1 = Psum.unsqueeze(-1) - Ptop_cum_shift
|
||||
|
||||
temp = torch.arange(K, -1, -1) # [K, K-1, ..., 0]
|
||||
temp = torch.arange(K, -1, -1, device=Psum.device) # [K, K-1, ..., 0]
|
||||
# Kk, of shape (K,), contains [K, K-1, ..., 1], representing K-k for k = [0, 1, ..., K-1]
|
||||
Kk = temp[0:K]
|
||||
# Kk1 of shape (K,), contains [K-1, K-2, ..., 0], representing K-k-1 for k = [0, 1, ..., K-1]
|
||||
@ -637,7 +637,7 @@ def sample_combined_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tens
|
||||
# rand_values are random in {0, 1, ..., B-1}
|
||||
rand = torch.randint((2**63 - 1), B.shape) % B
|
||||
# rand, rand + B, rand + 2B, ...., rand + (K-1)B
|
||||
samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K)
|
||||
samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
|
||||
|
||||
shifted_samples = compute_shifted_samples(combined_cumsums_mod,
|
||||
delta_P,
|
||||
@ -794,7 +794,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
|
||||
assert torch.all((s * inv_s) % M == 1) # if this fails, check that M is a power of 2
|
||||
|
||||
# R = pseudo-random re-ordering of p.
|
||||
R = torch.minimum(torch.gather(P, dim=-1, index=(s * torch.arange(M)) % M),
|
||||
R = torch.minimum(torch.gather(P, dim=-1, index=(s * torch.arange(M, device=P.device)) % M),
|
||||
B)
|
||||
# S = inclusive-sum of R
|
||||
S = torch.cumsum(R, dim=-1)
|
||||
@ -802,7 +802,7 @@ def soft_sample_forward(p: Tensor, K: int, input_is_log: bool) -> Tuple[Tensor,
|
||||
# Let b be a random integer drawn uniformly from {0, 1, ..., B-1}.
|
||||
b = torch.randint((2**63 - 1), B.shape) % B
|
||||
|
||||
S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1), S[...,:-1]), dim=-1)
|
||||
S_prev = torch.cat((torch.zeros(*S.shape[:-1], 1, device=S.device), S[...,:-1]), dim=-1)
|
||||
|
||||
k_prev = (S_prev + b) // B
|
||||
k_cur = (S + b) // B
|
||||
@ -1060,7 +1060,7 @@ def _test_combined():
|
||||
# rand_values are random in {0, 1, ..., B-1}
|
||||
rand = torch.randint((2**63 - 1), B.shape) % B
|
||||
# rand, rand + B, rand + 2B, ...., rand + (K-1)B
|
||||
samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K)
|
||||
samples = rand.unsqueeze(-1) + B.unsqueeze(-1) * torch.arange(K, device=B.device)
|
||||
print("rand = ", rand)
|
||||
print("sampled = ", samples)
|
||||
|
||||
@ -1150,6 +1150,10 @@ def _test_knowledge_base_lookup():
|
||||
print("x.grad = ", x.grad)
|
||||
print("knowlege_base.grad norm = ", knowledge_base.grad.norm())
|
||||
|
||||
|
||||
#train_pairs = [ (torch.randn(B, T, E), torch.randn(B, T, E)) for _ in range(100) ]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_test_sample_combined()
|
||||
_test_sample_combined_mean()
|
||||
|
Loading…
x
Reference in New Issue
Block a user