Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-08 00:24:19 +00:00

Commit 370b839a43: Merge branch 'k2-fsa:master' into fix/css_baseline
@@ -635,7 +635,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):

         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -800,7 +799,7 @@ def run(rank, world_size, args):
@@ -872,7 +872,7 @@ def run(rank, world_size, args):
@@ -1045,7 +1045,7 @@ def run(rank, world_size, args):
@@ -1028,7 +1028,7 @@ def run(rank, world_size, args):
@@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
@@ -1019,7 +1019,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -730,7 +730,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):

         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -919,7 +918,7 @@ def run(rank, world_size, args):
@@ -908,7 +908,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -635,7 +635,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):

         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
@@ -800,7 +799,7 @@ def run(rank, world_size, args):
@@ -999,7 +999,7 @@ def run(rank, world_size, args):
@@ -988,7 +988,7 @@ def run(rank, world_size, args):
@@ -1019,7 +1019,7 @@ def run(rank, world_size, args):
@@ -1074,7 +1074,7 @@ def run(rank, world_size, args):
@@ -1075,7 +1075,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -557,7 +557,6 @@ def train_one_epoch(
            )

        if batch_idx % params.log_interval == 0:

            if tb_writer is not None:
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
@@ -953,7 +953,7 @@ def run(rank, world_size, args):
@@ -953,7 +953,7 @@ def run(rank, world_size, args):
@@ -955,7 +955,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -43,6 +43,7 @@ from pathlib import Path

 from tqdm.auto import tqdm


 # This function is copied from lhotse
 def tqdm_urlretrieve_hook(t):
     """Wraps tqdm instance.
@@ -236,7 +236,7 @@ def greedy_search_batch(
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

     offset = 0
-    for (t, batch_size) in enumerate(batch_size_list):
+    for t, batch_size in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -507,7 +507,7 @@ def modified_beam_search(
     offset = 0
     finalized_B = []
-    for (t, batch_size) in enumerate(batch_size_list):
+    for t, batch_size in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -162,7 +162,6 @@ def merge_chunks(
     futures = []
     with ThreadPoolExecutor(max_workers=1) as executor:

         for cut in cuts_chunk:
             cur_rec_id = cut.recording.id
             if len(cut_list) == 0:
@@ -264,6 +264,7 @@ def decode_dataset(
       - timestamps of reference transcript
       - timestamps of predicted result
     """

     # Background worker to add alignemnt and save cuts to disk.
     def _save_worker(
         cuts: List[Cut],
@@ -66,7 +66,6 @@ class Eve(Optimizer):
         weight_decay=1e-3,
         target_rms=0.1,
     ):

         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -811,7 +811,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -719,7 +719,7 @@ def greedy_search_batch(
     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

     offset = 0
-    for (t, batch_size) in enumerate(batch_size_list):
+    for t, batch_size in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -1019,7 +1019,7 @@ def modified_beam_search(
@@ -1227,7 +1227,7 @@ def modified_beam_search_lm_rescore(
@@ -1427,7 +1427,7 @@ def modified_beam_search_lm_rescore_LODR(
     offset = 0
     finalized_B = []
-    for (t, batch_size) in enumerate(batch_size_list):
+    for t, batch_size in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -2608,7 +2608,6 @@ def modified_beam_search_LODR(
             context_score = 0
             new_context_state = None if context_graph is None else hyp.context_state
             if new_token not in (blank_id, unk_id):

                 if context_graph is not None:
                     (
                         context_score,
@@ -2758,7 +2757,7 @@ def modified_beam_search_lm_shallow_fusion(
     offset = 0
     finalized_B = []
-    for (t, batch_size) in enumerate(batch_size_list):
+    for t, batch_size in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]  # get batch
@@ -2900,7 +2899,6 @@ def modified_beam_search_lm_shallow_fusion(
             new_token = topk_token_indexes[k]
             new_timestamp = hyp.timestamp[:]
             if new_token not in (blank_id, unk_id):

                 ys.append(new_token)
                 new_timestamp.append(t)
@@ -66,7 +66,6 @@ class Eve(Optimizer):
         weight_decay=1e-3,
         target_rms=0.1,
     ):

         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -528,7 +528,6 @@ class ScaledLSTM(nn.LSTM):
             return

         with torch.cuda.device_of(first_fw):

             # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
             # an inplace operation on self._flat_weights
             with torch.no_grad():
@@ -1003,7 +1003,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -56,7 +56,6 @@ class CodebookIndexExtractor:
     """

    def __init__(self, params: AttributeDict):

        self.params = params
        params.subsets = ["clean-100"]
        if self.params.full_libri:
@@ -111,7 +111,7 @@ def batch_force_alignment(
     offset = 0
     finalized_B = []
-    for (t, batch_size) in enumerate(batch_size_list):
+    for t, batch_size in enumerate(batch_size_list):
         start = offset
         end = offset + batch_size
         current_encoder_out = encoder_out.data[start:end]
@@ -1132,7 +1132,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -117,7 +117,7 @@ class BatchedOptimizer(Optimizer):
         yield tuples  # <-- calling code will do the actual optimization here!

-        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+        for (stacked_params, _state, _names), batch in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
@@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
         parameters_names=None,
         show_dominant_parameters=True,
     ):

         assert parameters_names is not None, (
             "Please prepare parameters_names,"
             "which is a List[List[str]]. Each List[str] is for a group"
@@ -224,9 +223,7 @@ class ScaledAdam(BatchedOptimizer):
             batch = True

         for group, group_params_names in zip(self.param_groups, self.parameters_names):

             with self.batched_params(group["params"], group_params_names) as batches:

                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                 # a stacking dim, it is not a real dim.
@@ -325,7 +322,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]

         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state, param_names) in tuples:
+        for p, state, param_names in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -410,7 +407,7 @@ class ScaledAdam(BatchedOptimizer):
        from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
-       for (p, state, batch_param_names) in tuples:
+       for p, state, batch_param_names in tuples:
            # p is a stacked batch parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
@@ -426,7 +423,6 @@ class ScaledAdam(BatchedOptimizer):
         for name, sumsq_orig, rms, grad in zip(
             batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
         ):

             proportion_orig = sumsq_orig / tot_sumsq
             all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
@@ -1039,7 +1035,7 @@ def _test_scaled_adam(hidden_dim: int):
     # if epoch == 130:
     #     opts = diagnostics.TensorDiagnosticOptions(
-    #         2 ** 22
+    #         512
     #     )  # allow 4 megabytes per sub-module
     #     diagnostic = diagnostics.attach_diagnostics(m, opts)
@@ -1028,7 +1028,7 @@ def run(rank, world_size, args):
@@ -1052,7 +1052,7 @@ def run(rank, world_size, args):
@@ -1042,7 +1042,7 @@ def run(rank, world_size, args):
@@ -1029,7 +1029,7 @@ def run(rank, world_size, args):
@@ -1030,7 +1030,7 @@ def run(rank, world_size, args):
@@ -1141,7 +1141,7 @@ def run(rank, world_size, args):
@@ -1154,7 +1154,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -230,7 +230,9 @@ class Conformer(Transformer):
                 x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
             )  # (T, B, F)
         else:
-            x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask)  # (T, B, F)
+            x = self.encoder(
+                x, pos_emb, src_key_padding_mask=src_key_padding_mask
+            )  # (T, B, F)

         if self.normalize_before:
             x = self.after_norm(x)
@@ -543,7 +543,6 @@ def train_one_epoch(
            )

        if batch_idx % params.log_interval == 0:

            if tb_writer is not None:
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
@@ -463,7 +463,6 @@ def train_one_epoch(
            f"tot_loss[{tot_loss}], batch size: {batch_size}"
        )
        if batch_idx % params.log_interval == 0:

            if tb_writer is not None:
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
@@ -513,7 +513,6 @@ def train_one_epoch(
@@ -517,7 +517,6 @@ def train_one_epoch(
            )

        if batch_idx % params.log_interval == 0:

            if tb_writer is not None:
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
@@ -61,10 +61,15 @@ class Decoder(nn.Module):
         )
         # the balancers are to avoid any drift in the magnitude of the
         # embeddings, which would interact badly with parameter averaging.
-        self.balancer = Balancer(decoder_dim, channel_dim=-1,
-                                 min_positive=0.0, max_positive=1.0,
-                                 min_abs=0.5, max_abs=1.0,
-                                 prob=0.05)
+        self.balancer = Balancer(
+            decoder_dim,
+            channel_dim=-1,
+            min_positive=0.0,
+            max_positive=1.0,
+            min_abs=0.5,
+            max_abs=1.0,
+            prob=0.05,
+        )

         self.blank_id = blank_id
@@ -81,10 +86,15 @@ class Decoder(nn.Module):
                 groups=decoder_dim // 4,  # group size == 4
                 bias=False,
             )
-            self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
-                                      min_positive=0.0, max_positive=1.0,
-                                      min_abs=0.5, max_abs=1.0,
-                                      prob=0.05)
+            self.balancer2 = Balancer(
+                decoder_dim,
+                channel_dim=-1,
+                min_positive=0.0,
+                max_positive=1.0,
+                min_abs=0.5,
+                max_abs=1.0,
+                prob=0.05,
+            )

     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
@@ -107,9 +117,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
@@ -52,12 +52,13 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
+        assert encoder_out.ndim == decoder_out.ndim, (
+            encoder_out.shape,
+            decoder_out.shape,
+        )

         if project_input:
-            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
-                decoder_out
-            )
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
             logit = encoder_out + decoder_out
@@ -303,7 +303,9 @@ def main():
     for test_set, test_dl in zip(test_sets, test_dl):
         start_time = time.time()
-        results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
+        results, total_duration = decode_dataset(
+            dl=test_dl, model=model, token_table=token_table
+        )
         end_time = time.time()
         elapsed_seconds = end_time - start_time
         rtf = elapsed_seconds / total_duration
@@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
         yield tuples  # <-- calling code will do the actual optimization here!

-        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+        for (stacked_params, _state, _names), batch in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])
@@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
         size_update_period=4,
         clipping_update_period=100,
     ):

         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -299,8 +298,8 @@ class ScaledAdam(BatchedOptimizer):
             # the input is groups of parameter or named parameter.
             for cur_group in iterable_or_groups:
                 assert "named_params" in cur_group
-                name_list = [ x[0] for x in cur_group["named_params"] ]
-                p_list = [ x[1] for x in cur_group["named_params"] ]
+                name_list = [x[0] for x in cur_group["named_params"]]
+                p_list = [x[1] for x in cur_group["named_params"]]
                 del cur_group["named_params"]
                 cur_group["params"] = p_list
                 param_groups.append(cur_group)
@@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer):
             batch = True

         for group, group_params_names in zip(self.param_groups, self.parameters_names):

             with self.batched_params(group["params"], group_params_names) as batches:

                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                 # a stacking dim, it is not a real dim.
@@ -428,7 +425,7 @@ class ScaledAdam(BatchedOptimizer):
         clipping_update_period = group["clipping_update_period"]

         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for (p, state, param_names) in tuples:
+        for p, state, param_names in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
@@ -513,7 +510,7 @@ class ScaledAdam(BatchedOptimizer):
        from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
-       for (p, state, batch_param_names) in tuples:
+       for p, state, batch_param_names in tuples:
            # p is a stacked batch parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
@@ -529,7 +526,6 @@ class ScaledAdam(BatchedOptimizer):
         for name, sumsq_orig, rms, grad in zip(
             batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
         ):

             proportion_orig = sumsq_orig / tot_sumsq
             all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
@@ -667,8 +663,7 @@ class ScaledAdam(BatchedOptimizer):
                 # We have to look at the trained model for parameters at or around the
                 # param_max_rms, because sometimes they can indicate a problem with the
                 # topology or settings.
-                scale_step = torch.minimum(scale_step,
-                                           (param_max_rms - param_rms) / param_rms)
+                scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)

         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
@@ -879,7 +874,8 @@ class Eden(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )
@@ -1111,7 +1107,7 @@ def _test_scaled_adam(hidden_dim: int):
     # if epoch == 130:
     #     opts = diagnostics.TensorDiagnosticOptions(
-    #         2 ** 22
+    #         512
     #     )  # allow 4 megabytes per sub-module
     #     diagnostic = diagnostics.attach_diagnostics(m, opts)
|
@ -100,17 +100,13 @@ class Model(nn.Module):
|
|||||||
self.encoder_embed = encoder_embed
|
self.encoder_embed = encoder_embed
|
||||||
self.encoder_proj = encoder_proj
|
self.encoder_proj = encoder_proj
|
||||||
|
|
||||||
def forward(
|
def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
self, feature: Tensor, feature_lens: Tensor
|
|
||||||
) -> Tuple[Tensor, Tensor]:
|
|
||||||
x, x_lens = self.encoder_embed(feature, feature_lens)
|
x, x_lens = self.encoder_embed(feature, feature_lens)
|
||||||
|
|
||||||
src_key_padding_mask = make_pad_mask(x_lens)
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = self.encoder(
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||||
x, x_lens, src_key_padding_mask
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
logits = self.encoder_proj(encoder_out)
|
logits = self.encoder_proj(encoder_out)
|
||||||
@@ -168,9 +164,7 @@ def main():
 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

     main()
File diff suppressed because it is too large
@@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
         )
     batch_states.append(cached_embed_left_pad)

-    processed_lens = torch.cat(
-        [state_list[i][-1] for i in range(batch_size)], dim=0
-    )
+    processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
     batch_states.append(processed_lens)

     return batch_states
@@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
     for layer in range(tot_num_layers):
         layer_offset = layer * 6
         # cached_key: (left_context_len, batch_size, key_dim)
-        cached_key_list = batch_states[layer_offset].chunk(
-            chunks=batch_size, dim=1
-        )
+        cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
         # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
         cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
             chunks=batch_size, dim=1
@@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
             cached_conv2_list[i],
         ]

-    cached_embed_left_pad_list = batch_states[-2].chunk(
-        chunks=batch_size, dim=0
-    )
+    cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
     for i in range(batch_size):
         state_list[i].append(cached_embed_left_pad_list[i])
@@ -380,11 +374,7 @@ def streaming_forward(
       Returns encoder outputs, output lengths, and updated states.
     """
     cached_embed_left_pad = states[-2]
-    (
-        x,
-        x_lens,
-        new_cached_embed_left_pad,
-    ) = model.encoder_embed.streaming_forward(
+    (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
         x=features,
         x_lens=feature_lens,
         cached_left_pad=cached_embed_left_pad,
@@ -404,9 +394,7 @@ def streaming_forward(
     new_processed_lens = processed_lens + x_lens

     # (batch, left_context_size + chunk_size)
-    src_key_padding_mask = torch.cat(
-        [processed_mask, src_key_padding_mask], dim=1
-    )
+    src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)

     x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
     encoder_states = states[:-2]
@@ -494,9 +482,7 @@ def decode_one_chunk(
     encoder_out = model.joiner.encoder_proj(encoder_out)

     if params.decoding_method == "greedy_search":
-        greedy_search(
-            model=model, encoder_out=encoder_out, streams=decode_streams
-        )
+        greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
     elif params.decoding_method == "fast_beam_search":
         processed_lens = torch.tensor(processed_lens, device=device)
         processed_lens = processed_lens + encoder_out_lens
@@ -517,9 +503,7 @@ def decode_one_chunk(
             num_active_paths=params.num_active_paths,
         )
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

     states = unstack_states(new_states)
@@ -577,9 +561,7 @@ def decode_dataset(
     decode_streams = []
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
-        initial_states = get_init_states(
-            model=model, batch_size=1, device=device
-        )
+        initial_states = get_init_states(model=model, batch_size=1, device=device)
         decode_stream = DecodeStream(
             params=params,
             cut_id=cut.id,
@@ -649,9 +631,7 @@ def decode_dataset(
     elif params.decoding_method == "modified_beam_search":
         key = f"num_active_paths_{params.num_active_paths}"
     else:
-        raise ValueError(
-            f"Unsupported decoding method: {params.decoding_method}"
-        )
+        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
     return {key: decode_results}
@@ -684,8 +664,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
@@ -718,9 +697,7 @@ def main():
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     assert params.causal, params.causal
-    assert (
-        "," not in params.chunk_size
-    ), "chunk_size should be one value in decoding."
+    assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
     assert (
         "," not in params.left_context_frames
     ), "left_context_frames should be one value in decoding."
@@ -760,9 +737,9 @@ def main():
     if not params.use_averaged_model:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
@@ -789,9 +766,9 @@ def main():
             model.load_state_dict(average_checkpoints(filenames, device=device))
     else:
         if params.iter > 0:
-            filenames = find_checkpoints(
-                params.exp_dir, iteration=-params.iter
-            )[: params.avg + 1]
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
             if len(filenames) == 0:
                 raise ValueError(
                     f"No checkpoints found for"
|
@ -107,9 +107,7 @@ class ConvNeXt(nn.Module):
|
|||||||
if layerdrop_rate != 0.0:
|
if layerdrop_rate != 0.0:
|
||||||
batch_size = x.shape[0]
|
batch_size = x.shape[0]
|
||||||
mask = (
|
mask = (
|
||||||
torch.rand(
|
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
|
||||||
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
|
|
||||||
)
|
|
||||||
> layerdrop_rate
|
> layerdrop_rate
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -278,9 +276,7 @@ class Conv2dSubsampling(nn.Module):
         # many copies of this extra gradient term.
         self.out_whiten = Whiten(
             num_groups=1,
-            whitening_limit=ScheduledFloat(
-                (0.0, 4.0), (20000.0, 8.0), default=4.0
-            ),
+            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
             prob=(0.025, 0.25),
             grad_scale=0.02,
         )
@@ -331,7 +327,7 @@ class Conv2dSubsampling(nn.Module):
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             x_lens = (x_lens - 7) // 2
-        assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max())
+        assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())

         return x, x_lens
@@ -403,8 +399,8 @@ class Conv2dSubsampling(nn.Module):
         left_pad = self.convnext.padding[0]
         freq = self.out_width
         channels = self.layer3_channels
-        cached_embed_left_pad = torch.zeros(
-            batch_size, channels, left_pad, freq
-        ).to(device)
+        cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
+            device
+        )

         return cached_embed_left_pad
|
@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
|
|
||||||
|
|
||||||
def get_model(params: AttributeDict) -> nn.Module:
|
def get_model(params: AttributeDict) -> nn.Module:
|
||||||
assert (
|
assert params.use_transducer or params.use_ctc, (
|
||||||
params.use_transducer or params.use_ctc
|
f"At least one of them should be True, "
|
||||||
), (f"At least one of them should be True, "
|
|
||||||
f"but got params.use_transducer={params.use_transducer}, "
|
f"but got params.use_transducer={params.use_transducer}, "
|
||||||
f"params.use_ctc={params.use_ctc}")
|
f"params.use_ctc={params.use_ctc}"
|
||||||
|
)
|
||||||
|
|
||||||
encoder_embed = get_encoder_embed(params)
|
encoder_embed = get_encoder_embed(params)
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
@@ -808,17 +808,16 @@ def compute_loss(
         # take down the scale on the simple loss from 1.0 at the start
         # to params.simple_loss scale by warm_step.
         simple_loss_scale = (
-            s if batch_idx_train >= warm_step
+            s
+            if batch_idx_train >= warm_step
             else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
         )
         pruned_loss_scale = (
-            1.0 if batch_idx_train >= warm_step
+            1.0
+            if batch_idx_train >= warm_step
             else 0.1 + 0.9 * (batch_idx_train / warm_step)
         )
-        loss += (
-            simple_loss_scale * simple_loss
-            + pruned_loss_scale * pruned_loss
-        )
+        loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

     if params.use_ctc:
         loss += params.ctc_loss_scale * ctc_loss
@@ -1166,7 +1165,7 @@ def run(rank, world_size, args):
@@ -981,7 +981,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -746,7 +746,6 @@ def train_one_epoch(
     tot_loss = MetricsTracker()

     for batch_idx, batch in enumerate(train_dl):

         if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
             params.batch_idx_train += 1
             batch_size = len(batch["supervisions"]["text"])
@@ -966,7 +965,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@@ -1019,7 +1018,6 @@ def run(rank, world_size, args):
         scaler.load_state_dict(checkpoints["grad_scaler"])

     for epoch in range(params.start_epoch, params.num_epochs + 1):

         scheduler.step_epoch(epoch - 1)
         fix_random_seed(params.seed + epoch - 1)
         train_dl.sampler.set_epoch(epoch - 1)
@@ -1118,7 +1116,6 @@ def scan_pessimistic_batches_for_oom(
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
         with torch.cuda.amp.autocast(enabled=params.use_fp16):

             loss, _, _ = compute_loss(
                 params=params,
                 model=model,
@@ -1164,7 +1164,7 @@ def run(rank, world_size, args):
@@ -915,7 +915,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            512
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -69,7 +69,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from zipformer import Zipformer
|
from zipformer import Zipformer
|
||||||
|
|
||||||
from icefall import diagnostics, byte_encode, tokenize_by_CJK_char
|
from icefall import byte_encode, diagnostics, tokenize_by_CJK_char
|
||||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -1018,7 +1018,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
512
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
@ -905,7 +905,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
512
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
@ -1126,7 +1126,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
512
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
@ -886,7 +886,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
512
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
@ -851,7 +851,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
512
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
@ -985,7 +985,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
512
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
@ -1128,7 +1128,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
512
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
@ -1001,7 +1001,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
512
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
@ -993,7 +993,7 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
opts = diagnostics.TensorDiagnosticOptions(
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
2**22
|
512
|
||||||
) # allow 4 megabytes per sub-module
|
) # allow 4 megabytes per sub-module
|
||||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
|
@@ -1,12 +1,6 @@
 # isort:skip_file

-from . import (
-    checkpoint,
-    decode,
-    dist,
-    env,
-    utils
-)
+from . import checkpoint, decode, dist, env, utils

 from .byte_utils import (
     byte_decode,
@@ -227,7 +227,6 @@ class ContextGraph:
         filename: Optional[str] = "",
         symbol_table: Optional[Dict[int, str]] = None,
     ) -> "Digraph":  # noqa
-
         """Visualize a ContextGraph via graphviz.

         Render ContextGraph as an image via graphviz, and return the Digraph object;
@@ -23,6 +23,7 @@ from typing import Optional, Tuple, List
 import torch
 from torch import Tensor, nn

+
 class TensorDiagnosticOptions(object):
     """Options object for tensor diagnostics:

@@ -77,11 +78,11 @@ def get_tensor_stats(
     elif stats_type == "abs":
         x = x.abs()
     elif stats_type == "rms":
-        x = x ** 2
+        x = x**2
     elif stats_type == "positive":
         x = (x > 0).to(dtype=torch.float)
     else:
-        assert stats_type in [ "value", "max", "min" ]
+        assert stats_type in ["value", "max", "min"]

     sum_dims = [d for d in range(x.ndim) if d != dim]
     if len(sum_dims) > 0:
@@ -121,10 +122,10 @@ class TensorDiagnostic(object):
         self.class_name = None  # will assign in accumulate()

         self.stats = None  # we'll later assign a list to self.stats.
         # It's a list of dicts, indexed by dim (i.e. by the
         # axis of the tensor).  The dicts, in turn, are
         # indexed by `stats-type` which are strings in
         # ["abs", "max", "min", "positive", "value", "rms"].

         # scalar_stats contains some analysis of the activations and gradients,
         self.scalar_stats = None
@@ -139,7 +140,6 @@ class TensorDiagnostic(object):
         # only adding a new element to the list if there was a different dim.
         # if the string in the key is "eigs", if we detect a length mismatch we put None as the value.
-

     def accumulate(self, x, class_name: Optional[str] = None):
         """
         Accumulate tensors.
@@ -193,17 +193,12 @@ class TensorDiagnostic(object):
                     done = True
                     break
             if not done:
-                if (
-                    this_dim_stats[stats_type] != []
-                    and stats_type == "eigs"
-                ):
+                if this_dim_stats[stats_type] != [] and stats_type == "eigs":
                     # >1 size encountered on this dim, e.g. it's a batch or time dimension,
                     # don't accumulat "eigs" stats type, it uses too much memory
                     this_dim_stats[stats_type] = None
                 else:
-                    this_dim_stats[stats_type].append(
-                        TensorAndCount(stats, count)
-                    )
+                    this_dim_stats[stats_type].append(TensorAndCount(stats, count))

     def print_diagnostics(self):
         """Print diagnostics for each dimension of the tensor."""
@@ -220,8 +215,11 @@ class TensorDiagnostic(object):
                 for r, v in zip(rms_stats_list, value_stats_list):
                     stddev_stats_list.append(
                         # r.count and v.count should be the same, but we don't check this.
-                        TensorAndCount(r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
-                                       r.count))
+                        TensorAndCount(
+                            r.tensor - v.tensor * v.tensor / (v.count + 1.0e-20),
+                            r.count,
+                        )
+                    )
                 this_dim_stats["stddev"] = stddev_stats_list

             for stats_type, stats_list in this_dim_stats.items():
@@ -232,7 +230,6 @@ class TensorDiagnostic(object):
                     assert stats_type == "eigs"
                     continue
-

                 def get_count(count):
                     return 1 if stats_type in ["max", "min"] else count

@@ -250,22 +247,20 @@ class TensorDiagnostic(object):
                         eigs, _ = torch.symeig(stats)
                         stats = eigs.abs().sqrt()
                     except:  # noqa
-                        print(
-                            "Error getting eigenvalues, trying another method."
-                        )
+                        print("Error getting eigenvalues, trying another method.")
                         eigs, _ = torch.eig(stats)
                         stats = eigs.norm(dim=1).sqrt()
                 # sqrt so it reflects data magnitude, like stddev- not variance

-                if stats_type in [ "rms", "stddev" ]:
+                if stats_type in ["rms", "stddev"]:
                     # we stored the square; after aggregation we need to take sqrt.
                     stats = stats.sqrt()

                 # if `summarize` we print percentiles of the stats; else,
                 # we print out individual elements.
-                summarize = (
-                    len(stats_list) > 1
-                ) or self.opts.dim_is_summarized(stats.numel())
+                summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized(
+                    stats.numel()
+                )
                 if summarize:  # usually `summarize` will be true
                     # print out percentiles.
                     stats = stats.sort()[0]
@@ -282,15 +277,15 @@ class TensorDiagnostic(object):
                     ans = stats.tolist()
                     ans = ["%.2g" % x for x in ans]
                 ans = "[" + " ".join(ans) + "]"
-                if stats_type in [ "value", "rms", "stddev", "eigs" ]:
+                if stats_type in ["value", "rms", "stddev", "eigs"]:
                     # This norm is useful because it is strictly less than the largest
                     # sqrt(eigenvalue) of the variance, which we print out, and shows,
                     # speaking in an approximate way, how much of that largest eigenvalue
                     # can be attributed to the mean of the distribution.
-                    norm = (stats ** 2).sum().sqrt().item()
+                    norm = (stats**2).sum().sqrt().item()
                     ans += f", norm={norm:.2g}"
                 mean = stats.mean().item()
-                rms = (stats ** 2).mean().sqrt().item()
+                rms = (stats**2).mean().sqrt().item()
                 ans += f", mean={mean:.3g}, rms={rms:.3g}"

                 # OK, "ans" contains the actual stats, e.g.
@@ -298,11 +293,11 @@ class TensorDiagnostic(object):

                 sizes = [x.tensor.shape[0] for x in stats_list]
                 size_str = (
-                    f"{sizes[0]}"
-                    if len(sizes) == 1
-                    else f"{min(sizes)}..{max(sizes)}"
+                    f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}"
                 )
-                maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
+                maybe_class_name = (
+                    f" type={self.class_name}," if self.class_name is not None else ""
+                )
                 print(
                     f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}"
                 )
@@ -330,7 +325,6 @@ class ScalarDiagnostic(object):
         self.sum_gradsq = None
         self.sum_abs_grad = None
-

     def accumulate_input(self, x: Tensor, class_name: Optional[str] = None):
         """
         Called in forward pass.
@@ -347,8 +341,10 @@ class ScalarDiagnostic(object):

             limit = 10
             if len(self.saved_inputs) > limit:
-                print(f"ERROR: forward pass called for this module over {limit} times with no backward pass. "
-                      f" Will not accumulate scalar stats.")
+                print(
+                    f"ERROR: forward pass called for this module over {limit} times with no backward pass. "
+                    f" Will not accumulate scalar stats."
+                )
                 self.is_ok = False
                 return
             self.saved_inputs.append(x)
@@ -359,11 +355,15 @@ class ScalarDiagnostic(object):
         if self.is_forward_pass:
             self.is_forward_pass = False

-            last_shape = 'n/a' if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
+            last_shape = (
+                "n/a" if len(self.saved_inputs) == 0 else self.saved_inputs[-1].shape
+            )
             if len(self.saved_inputs) == 0 or grad.shape != last_shape:
-                print(f"ERROR: shape mismatch or no forward activation present when backward "
-                      f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}"
-                      f", shape-of-last-saved-input={last_shape}")
+                print(
+                    f"ERROR: shape mismatch or no forward activation present when backward "
+                    f"pass called: grad shape ={tuple(grad.shape)}, num-saved-inputs={len(self.saved_inputs)}"
+                    f", shape-of-last-saved-input={last_shape}"
+                )
                 self.is_ok = False
                 return

@@ -384,11 +384,19 @@ class ScalarDiagnostic(object):
             self.tick_scale = float(x_abs_sorted[index] / num_ticks_per_side)

             # integerize from tick * (-num ticks_per_side .. num_ticks_per_side - 1]
-            self.counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long, device=x.device)
-            self.sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
+            self.counts = torch.zeros(
+                2 * num_ticks_per_side, dtype=torch.long, device=x.device
+            )
+            self.sum_grad = torch.zeros(
+                2 * num_ticks_per_side, dtype=torch.double, device=x.device
+            )
             # sum_gradsq is for getting error bars.
-            self.sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
-            self.sum_abs_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double, device=x.device)
+            self.sum_gradsq = torch.zeros(
+                2 * num_ticks_per_side, dtype=torch.double, device=x.device
+            )
+            self.sum_abs_grad = torch.zeros(
+                2 * num_ticks_per_side, dtype=torch.double, device=x.device
+            )

             # this will round down.
             x = (x / self.tick_scale).to(torch.long)
@@ -397,20 +405,21 @@ class ScalarDiagnostic(object):

         self.counts.index_add_(dim=0, index=x, source=torch.ones_like(x))
         self.sum_grad.index_add_(dim=0, index=x, source=grad.to(torch.double))
-        self.sum_gradsq.index_add_(dim=0, index=x, source=(grad*grad).to(torch.double))
+        self.sum_gradsq.index_add_(
+            dim=0, index=x, source=(grad * grad).to(torch.double)
+        )
         self.sum_abs_grad.index_add_(dim=0, index=x, source=grad.abs().to(torch.double))


     def print_diagnostics(self):
         """Print diagnostics."""
         if self.is_ok is False or self.counts is None:
             print(f"Warning: no stats accumulated for {self.name}, is_ok={self.is_ok}")
             return

-        counts = self.counts.to('cpu')
-        sum_grad = self.sum_grad.to(device='cpu', dtype=torch.float32)
-        sum_gradsq = self.sum_gradsq.to(device='cpu', dtype=torch.float32)
-        sum_abs_grad = self.sum_abs_grad.to(device='cpu', dtype=torch.float32)
+        counts = self.counts.to("cpu")
+        sum_grad = self.sum_grad.to(device="cpu", dtype=torch.float32)
+        sum_gradsq = self.sum_gradsq.to(device="cpu", dtype=torch.float32)
+        sum_abs_grad = self.sum_abs_grad.to(device="cpu", dtype=torch.float32)

         counts_cumsum = counts.cumsum(dim=0)
         counts_tot = counts_cumsum[-1]
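The ScalarDiagnostic hunks above only reformat (in black style) the code that buckets activation values into integer "ticks" and accumulates per-tick gradient statistics with index_add_. A minimal self-contained sketch of that accumulation pattern follows; tick_scale, the clamp bounds, and the toy inputs are illustrative assumptions, not code from the file.

# Sketch of the tick-bucketed accumulation used by ScalarDiagnostic above.
import torch

num_ticks_per_side = 4
tick_scale = 0.25  # assumed; the real code derives it from a percentile of |x|

x = torch.randn(1000)
grad = torch.randn(1000)

# Integerize x into ticks in [-num_ticks_per_side, num_ticks_per_side - 1],
# then shift to non-negative indices so index_add_ can be used.
ticks = (x / tick_scale).to(torch.long)
ticks = ticks.clamp(min=-num_ticks_per_side, max=num_ticks_per_side - 1)
idx = ticks + num_ticks_per_side

counts = torch.zeros(2 * num_ticks_per_side, dtype=torch.long)
sum_grad = torch.zeros(2 * num_ticks_per_side, dtype=torch.double)
sum_gradsq = torch.zeros(2 * num_ticks_per_side, dtype=torch.double)

counts.index_add_(dim=0, index=idx, source=torch.ones_like(idx))
sum_grad.index_add_(dim=0, index=idx, source=grad.to(torch.double))
sum_gradsq.index_add_(dim=0, index=idx, source=(grad * grad).to(torch.double))

# Per-tick mean gradient with a crude error bar, as in print_diagnostics().
avg_grad = sum_grad / (counts + 1)
avg_grad_err = sum_gradsq.sqrt() / (counts + 1)
print(counts.tolist(), [f"{g:.2g}" for g in avg_grad.tolist()])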
@@ -433,19 +442,22 @@ class ScalarDiagnostic(object):
         bin_abs_grad = torch.zeros(num_bins)
         bin_abs_grad.index_add_(dim=0, index=bin_indexes, source=sum_abs_grad)

-        avg_grad = (bin_grad / bin_counts)
+        avg_grad = bin_grad / bin_counts
         avg_grad_stddev = (bin_gradsq / bin_counts).sqrt()

-        bin_boundary_counts = torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
+        bin_boundary_counts = (
+            torch.arange(num_bins + 1, dtype=torch.long) * counts_per_bin
+        )
         bin_tick_indexes = torch.searchsorted(counts_cumsum, bin_boundary_counts)
         # boundaries are the "x" values between the bins, e.g. corresponding to the
         # locations of percentiles of the distribution.
         num_ticks_per_side = counts.numel() // 2
         bin_boundaries = (bin_tick_indexes - num_ticks_per_side) * self.tick_scale


         bin_grad = bin_grad / (bin_counts + 1)
-        bin_conf_interval = bin_gradsq.sqrt() / (bin_counts + 1)  # consider this a standard deviation.
+        bin_conf_interval = bin_gradsq.sqrt() / (
+            bin_counts + 1
+        )  # consider this a standard deviation.
         # bin_grad / bin_abs_grad will give us a sense for how important in a practical sense,
         # the gradients are.
         bin_abs_grad = bin_abs_grad / (bin_counts + 1)
@@ -458,8 +470,9 @@ class ScalarDiagnostic(object):
             x = "[" + " ".join(x) + "]"
             return x

-        maybe_class_name = f" type={self.class_name}," if self.class_name is not None else ""
+        maybe_class_name = (
+            f" type={self.class_name}," if self.class_name is not None else ""
+        )

         print(
             f"module={self.name},{maybe_class_name} bin-boundaries={tensor_to_str(bin_boundaries)}, "
@@ -467,7 +480,6 @@ class ScalarDiagnostic(object):
         )

-

 class ModelDiagnostic(object):
     """This class stores diagnostics for all tensors in the torch.nn.Module.

@@ -485,9 +497,8 @@ class ModelDiagnostic(object):
         self.opts = opts
         self.diagnostics = dict()
-

     def __getitem__(self, name: str):
-        T = ScalarDiagnostic if name[-7:] == '.scalar' else TensorDiagnostic
+        T = ScalarDiagnostic if name[-7:] == ".scalar" else TensorDiagnostic
         if name not in self.diagnostics:
             self.diagnostics[name] = T(self.opts, name)
         return self.diagnostics[name]
@@ -502,18 +513,19 @@ def get_class_name(module: nn.Module):
     ans = type(module).__name__
     # we put the below in try blocks in case anyone is using a different version of these modules that
     # might have different member names.
-    if ans == 'Balancer' or ans == 'ActivationBalancer':
+    if ans == "Balancer" or ans == "ActivationBalancer":
         try:
-            ans += f'[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]'
+            ans += f"[{float(module.min_positive)},{float(module.max_positive)},{float(module.min_abs)},{float(module.max_abs)}]"
         except:
             pass
-    elif ans == 'AbsValuePenalizer':
+    elif ans == "AbsValuePenalizer":
         try:
-            ans += f'[{module.limit}]'
+            ans += f"[{module.limit}]"
         except:
             pass
     return ans

+
 def attach_diagnostics(
     model: nn.Module, opts: Optional[TensorDiagnosticOptions] = None
 ) -> ModelDiagnostic:
@@ -538,73 +550,85 @@ def attach_diagnostics(
         if name == "":
             name = "<top-level>"


         # Setting model_diagnostic=ans and n=name below, instead of trying to
         # capture the variables, ensures that we use the current values.
         # (this matters for `name`, since the variable gets overwritten).
         # These closures don't really capture by value, only by
         # "the final value the variable got in the function" :-(
-        def forward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]

-            if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                _model_diagnostic[f"{_name}.output"].accumulate(_output,
-                    class_name=get_class_name(_module))
+            if isinstance(_output, Tensor) and _output.dtype in (
+                torch.float32,
+                torch.float16,
+                torch.float64,
+            ):
+                _model_diagnostic[f"{_name}.output"].accumulate(
+                    _output, class_name=get_class_name(_module)
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o,
-                            class_name=get_class_name(_module))
+                    if o.dtype in (torch.float32, torch.float16, torch.float64):
+                        _model_diagnostic[f"{_name}.output[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module)
+                        )

-        def backward_hook(
-            _module, _input, _output, _model_diagnostic=ans, _name=name
-        ):
+        def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name):
             if isinstance(_output, tuple) and len(_output) == 1:
                 _output = _output[0]
-            if isinstance(_output, Tensor) and _output.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                _model_diagnostic[f"{_name}.grad"].accumulate(_output,
-                    class_name=get_class_name(_module))
+            if isinstance(_output, Tensor) and _output.dtype in (
+                torch.float32,
+                torch.float16,
+                torch.float64,
+            ):
+                _model_diagnostic[f"{_name}.grad"].accumulate(
+                    _output, class_name=get_class_name(_module)
+                )
             elif isinstance(_output, tuple):
                 for i, o in enumerate(_output):
-                    if o.dtype in ( torch.float32, torch.float16, torch.float64 ):
-                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o,
-                            class_name=get_class_name(_module))
+                    if o.dtype in (torch.float32, torch.float16, torch.float64):
+                        _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(
+                            o, class_name=get_class_name(_module)
+                        )

         module.register_forward_hook(forward_hook)
         module.register_backward_hook(backward_hook)

-        if type(module).__name__ in ["Sigmoid", "Tanh", "ReLU", "TanSwish", "Swish", "DoubleSwish", "Swoosh"]:
+        if type(module).__name__ in [
+            "Sigmoid",
+            "Tanh",
+            "ReLU",
+            "TanSwish",
+            "Swish",
+            "DoubleSwish",
+            "Swoosh",
+        ]:
             # For these specific module types, accumulate some additional diagnostics
             # that can help us improve the activation function. These require a lot of memory,
             # to save the forward activations, so limit this to some select classes.
             # Note: this will not work correctly for all model types.
             def scalar_forward_hook(
                 _module, _input, _output, _model_diagnostic=ans, _name=name
             ):
                 if isinstance(_input, tuple):
-                    _input, = _input
+                    (_input,) = _input
                 assert isinstance(_input, Tensor)
-                _model_diagnostic[f"{_name}.scalar"].accumulate_input(_input,
-                    class_name=get_class_name(_module))
+                _model_diagnostic[f"{_name}.scalar"].accumulate_input(
+                    _input, class_name=get_class_name(_module)
+                )

             def scalar_backward_hook(
                 _module, _input, _output, _model_diagnostic=ans, _name=name
             ):
                 if isinstance(_output, tuple):
-                    _output, = _output
+                    (_output,) = _output
                 assert isinstance(_output, Tensor)
                 _model_diagnostic[f"{_name}.scalar"].accumulate_output_grad(_output)

             module.register_forward_hook(scalar_forward_hook)
             module.register_backward_hook(scalar_backward_hook)


     for name, parameter in model.named_parameters():

         def param_backward_hook(
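The hunk above only reflows the hook-registration code, but its comments point at a real pitfall: Python closures capture variables late, which is why each hook binds _name=name and _model_diagnostic=ans as default arguments. A small standalone illustration of that trick (not taken from the diff) is shown below.

# Why the hooks above freeze `name` via a default argument.
def make_hooks_late_binding(names):
    hooks = []
    for name in names:
        hooks.append(lambda: name)  # captures the variable, not its current value
    return hooks


def make_hooks_default_arg(names):
    hooks = []
    for name in names:
        hooks.append(lambda _name=name: _name)  # default arg freezes the value now
    return hooks


print([h() for h in make_hooks_late_binding(["a", "b", "c"])])  # ['c', 'c', 'c']
print([h() for h in make_hooks_default_arg(["a", "b", "c"])])   # ['a', 'b', 'c']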
@@ -70,25 +70,17 @@ class FlopsProfiler(object):
                 module_flop_count.append([])

             if not hasattr(module, "__pre_hook_handle__"):
-                module.__pre_hook_handle__ = module.register_forward_pre_hook(
-                    pre_hook
-                )
+                module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook)

             def post_hook(module, input, output):
                 if module_flop_count:
-                    module.__flops__ += sum(
-                        [elem[1] for elem in module_flop_count[-1]]
-                    )
+                    module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]])
                     module_flop_count.pop()

             if not hasattr(module, "__post_hook_handle__"):
-                module.__post_hook_handle__ = module.register_forward_hook(
-                    post_hook
-                )
+                module.__post_hook_handle__ = module.register_forward_hook(post_hook)

-        self.model.apply(
-            partial(register_module_hooks, ignore_list=ignore_list)
-        )
+        self.model.apply(partial(register_module_hooks, ignore_list=ignore_list))
         self.started = True
         self.func_patched = True

@@ -194,9 +186,7 @@ def _prelu_flops_compute(input: Tensor, weight: Tensor):
     return input.numel()


-def _elu_flops_compute(
-    input: Tensor, alpha: float = 1.0, inplace: bool = False
-):
+def _elu_flops_compute(input: Tensor, alpha: float = 1.0, inplace: bool = False):
     return input.numel()


@@ -259,9 +249,7 @@ def _conv_flops_compute(
         output_dims.append(output_dim)

     filters_per_channel = out_channels // groups
-    conv_per_position_macs = (
-        int(_prod(kernel_dims)) * in_channels * filters_per_channel
-    )
+    conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel
     active_elements_count = batch_size * int(_prod(output_dims))
     overall_conv_macs = conv_per_position_macs * active_elements_count
     overall_conv_flops = 2 * overall_conv_macs
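The convolution FLOPs estimate reflowed above multiplies per-position MACs (kernel size times in_channels times out_channels // groups) by the number of active output elements and doubles the result to count both multiplies and adds. A short worked sketch of that arithmetic follows; the helper name and the example numbers are illustrative, not part of the profiler.

# Worked example of the conv FLOPs formula mirrored from the hunk above.
from functools import reduce
from operator import mul


def conv_flops(batch_size, in_channels, out_channels, groups, kernel_dims, output_dims):
    filters_per_channel = out_channels // groups
    conv_per_position_macs = reduce(mul, kernel_dims) * in_channels * filters_per_channel
    active_elements_count = batch_size * reduce(mul, output_dims)
    overall_conv_macs = conv_per_position_macs * active_elements_count
    return 2 * overall_conv_macs  # one multiply + one add per MAC


# e.g. a 3x3 conv, 64 -> 128 channels, 56x56 output, batch of 8:
print(conv_flops(8, 64, 128, 1, (3, 3), (56, 56)))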
@@ -297,7 +285,6 @@ def _conv_trans_flops_compute(

     output_dims = []
     for idx, input_dim in enumerate(input_dims):
-
         output_dim = (
             input_dim
             + 2 * paddings[idx]
@@ -310,9 +297,7 @@ def _conv_trans_flops_compute(
     dilations = dilation if type(dilation) is tuple else (dilation, dilation)

     filters_per_channel = out_channels // groups
-    conv_per_position_macs = (
-        int(_prod(kernel_dims)) * in_channels * filters_per_channel
-    )
+    conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel
     active_elements_count = batch_size * int(_prod(input_dims))
     overall_conv_macs = conv_per_position_macs * active_elements_count
     overall_conv_flops = 2 * overall_conv_macs
@@ -389,9 +374,7 @@ def _upsample_flops_compute(input, **kwargs):
         else:
             return int(size), 0
     scale_factor = kwargs.get("scale_factor", None)
-    assert (
-        scale_factor is not None
-    ), "either size or scale_factor should be defined"
+    assert scale_factor is not None, "either size or scale_factor should be defined"
     flops = input.numel()
     if isinstance(scale_factor, tuple) and len(scale_factor) == len(input):
         flops * int(_prod(scale_factor))
@@ -593,12 +576,8 @@ def _patch_functionals():
     F.embedding = wrapFunc(F.embedding, _embedding_flops_compute)

     # swoosh functions in k2
-    k2.swoosh_l_forward = wrapFunc(
-        k2.swoosh_l_forward, _k2_swoosh_flops_compute
-    )
-    k2.swoosh_r_forward = wrapFunc(
-        k2.swoosh_r_forward, _k2_swoosh_flops_compute
-    )
+    k2.swoosh_l_forward = wrapFunc(k2.swoosh_l_forward, _k2_swoosh_flops_compute)
+    k2.swoosh_r_forward = wrapFunc(k2.swoosh_r_forward, _k2_swoosh_flops_compute)
     k2.swoosh_l = wrapFunc(k2.swoosh_l, _k2_swoosh_flops_compute)
     k2.swoosh_r = wrapFunc(k2.swoosh_r, _k2_swoosh_flops_compute)

@@ -612,9 +591,7 @@ def _patch_tensor_methods():
     torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute)

     torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute)
-    torch.Tensor.addmm = wrapFunc(
-        torch.Tensor.addmm, _tensor_addmm_flops_compute
-    )
+    torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute)

     torch.mul = wrapFunc(torch.mul, _mul_flops_compute)
     torch.Tensor.mul = wrapFunc(torch.Tensor.mul, _mul_flops_compute)
@@ -631,14 +608,10 @@ def _patch_tensor_methods():

     torch.tanh = wrapFunc(torch.tanh, _tanh_flops_compute)

-    torch.Tensor.softmax = wrapFunc(
-        torch.Tensor.softmax, _softmax_flops_compute
-    )
+    torch.Tensor.softmax = wrapFunc(torch.Tensor.softmax, _softmax_flops_compute)

     torch.sigmoid = wrapFunc(torch.sigmoid, _sigmoid_flops_compute)
-    torch.Tensor.sigmoid = wrapFunc(
-        torch.Tensor.sigmoid, _sigmoid_flops_compute
-    )
+    torch.Tensor.sigmoid = wrapFunc(torch.Tensor.sigmoid, _sigmoid_flops_compute)


 def _reload_functionals():
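The _patch_functionals / _patch_tensor_methods hunks reflowed above count FLOPs by monkey-patching functions such as F.embedding, torch.addmm, and the k2 swoosh functions with wrapped versions that record an estimate before delegating to the original. A minimal sketch of that wrapper pattern is below; the wrap_func signature, the flop_log list, and the matmul cost formula are assumptions for illustration, not the profiler's actual wrapFunc.

# Hypothetical sketch of the wrapFunc monkey-patching pattern used above.
import torch

flop_log = []


def wrap_func(func, flops_compute):
    def wrapped(*args, **kwargs):
        flop_log.append((func.__name__, flops_compute(*args, **kwargs)))
        return func(*args, **kwargs)

    return wrapped


def matmul_flops(a, b, *args, **kwargs):
    # 2 * prod(output shape) * shared dimension, ignoring broadcasting corner cases.
    out_shape = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2]) + (a.shape[-2], b.shape[-1])
    flops = 2 * a.shape[-1]
    for d in out_shape:
        flops *= d
    return flops


orig_matmul = torch.matmul
torch.matmul = wrap_func(orig_matmul, matmul_flops)
y = torch.matmul(torch.randn(8, 16), torch.randn(16, 32))
torch.matmul = orig_matmul  # restore the original afterwards
print(flop_log)  # [('matmul', 8192)]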
@@ -732,15 +705,11 @@ def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
         flops += rnn_module.hidden_size * 4
         # two hadamard _product and add for C state
         flops += (
-            rnn_module.hidden_size
-            + rnn_module.hidden_size
-            + rnn_module.hidden_size
+            rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
         )
         # final hadamard
         flops += (
-            rnn_module.hidden_size
-            + rnn_module.hidden_size
-            + rnn_module.hidden_size
+            rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
         )
     return flops

@@ -112,7 +112,6 @@ def main():
     for torch_v, onnx_v in zip(
         (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0)
     ):
-
         assert torch.allclose(torch_v, onnx_v, atol=1e-5), (
             torch_v.shape,
             onnx_v.shape,
@@ -463,7 +463,6 @@ def train_one_epoch(
     cur_batch_idx = params.get("cur_batch_idx", 0)

     for batch_idx, batch in enumerate(train_dl):
-
         if batch_idx < cur_batch_idx:
             continue
         cur_batch_idx = batch_idx
@@ -225,7 +225,6 @@ class NgramCounts:
         for n in range(0, self.ngram_order - 1):
             this_order_counts = self.counts[n]
             for hist, counts_for_hist in this_order_counts.items():
-
                 n_star_star = 0
                 for w in counts_for_hist.word_to_count.keys():
                     n_star_star += len(counts_for_hist.word_to_context[w])
@@ -424,7 +423,6 @@ class NgramCounts:


 if __name__ == "__main__":
-
     ngram_counts = NgramCounts(args.ngram_order)

     if args.text is None:
@@ -103,7 +103,6 @@ class TransformerLM(torch.nn.Module):
         return nll_loss

     def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
-
         bs = x.size(0)

         state = None
@@ -20,6 +20,7 @@ kaldialign==0.7.1
 sentencepiece==0.1.96
 tensorboard==2.8.0
 typeguard==2.13.3
+black==22.3.0
 multi_quantization

 onnx
@@ -5,3 +5,4 @@ sentencepiece>=0.1.96
 tensorboard
 typeguard
 dill
+black==22.3.0