Merge branch 'dev/bilingual' of https://github.com/jinzr/icefall into dev/bilingual

This commit is contained in:
JinZr 2023-09-26 13:34:36 +08:00
commit 851929b47f
111 changed files with 734 additions and 650 deletions

View File

@ -635,7 +635,6 @@ def train_one_epoch(
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -800,7 +799,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -872,7 +872,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1045,7 +1045,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -322,6 +322,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -151,12 +151,14 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -170,6 +172,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -1028,7 +1028,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1031,7 +1031,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1019,7 +1019,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -730,7 +730,6 @@ def train_one_epoch(
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -919,7 +918,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -908,7 +908,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -635,7 +635,6 @@ def train_one_epoch(
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -800,7 +799,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -999,7 +999,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -988,7 +988,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -330,6 +330,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -152,12 +152,14 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -171,6 +173,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -1019,7 +1019,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1074,7 +1074,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1075,7 +1075,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -557,7 +557,6 @@ def train_one_epoch(
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train

View File

@ -953,7 +953,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -401,6 +401,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -136,6 +136,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()
@ -184,6 +185,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -197,6 +199,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -953,7 +953,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -955,7 +955,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -43,6 +43,7 @@ from pathlib import Path
from tqdm.auto import tqdm
# This function is copied from lhotse
def tqdm_urlretrieve_hook(t):
"""Wraps tqdm instance.

View File

@ -236,7 +236,7 @@ def greedy_search_batch(
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@ -507,7 +507,7 @@ def modified_beam_search(
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]

View File

@ -162,7 +162,6 @@ def merge_chunks(
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:
for cut in cuts_chunk:
cur_rec_id = cut.recording.id
if len(cut_list) == 0:

View File

@ -264,6 +264,7 @@ def decode_dataset(
- timestamps of reference transcript
- timestamps of predicted result
"""
# Background worker to add alignemnt and save cuts to disk.
def _save_worker(
cuts: List[Cut],

View File

@ -359,6 +359,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -356,6 +356,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -129,6 +129,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()
@ -166,6 +167,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -179,6 +181,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -172,30 +172,35 @@ class Model:
self.encoder = ort.InferenceSession(
args.encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def init_decoder(self, args):
self.decoder = ort.InferenceSession(
args.decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def init_joiner(self, args):
self.joiner = ort.InferenceSession(
args.joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def init_joiner_encoder_proj(self, args):
self.joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def init_joiner_decoder_proj(self, args):
self.joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

View File

@ -66,7 +66,6 @@ class Eve(Optimizer):
weight_decay=1e-3,
target_rms=0.1,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:

View File

@ -811,7 +811,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -307,6 +307,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -719,7 +719,7 @@ def greedy_search_batch(
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@ -1019,7 +1019,7 @@ def modified_beam_search(
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@ -1227,7 +1227,7 @@ def modified_beam_search_lm_rescore(
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@ -1427,7 +1427,7 @@ def modified_beam_search_lm_rescore_LODR(
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
@ -2608,7 +2608,6 @@ def modified_beam_search_LODR(
context_score = 0
new_context_state = None if context_graph is None else hyp.context_state
if new_token not in (blank_id, unk_id):
if context_graph is not None:
(
context_score,
@ -2758,7 +2757,7 @@ def modified_beam_search_lm_shallow_fusion(
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end] # get batch
@ -2900,7 +2899,6 @@ def modified_beam_search_lm_shallow_fusion(
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
ys.append(new_token)
new_timestamp.append(t)

View File

@ -66,7 +66,6 @@ class Eve(Optimizer):
weight_decay=1e-3,
target_rms=0.1,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:

View File

@ -528,7 +528,6 @@ class ScaledLSTM(nn.LSTM):
return
with torch.cuda.device_of(first_fw):
# Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
# an inplace operation on self._flat_weights
with torch.no_grad():

View File

@ -312,6 +312,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -150,12 +150,14 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -169,6 +171,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -78,6 +78,7 @@ def test_conv2d_subsampling():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()
@ -133,6 +134,7 @@ def test_rel_pos():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()
@ -220,6 +222,7 @@ def test_conformer_encoder_layer():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()
@ -304,6 +307,7 @@ def test_conformer_encoder():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()
@ -359,6 +363,7 @@ def test_conformer():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()

View File

@ -404,6 +404,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -335,6 +335,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -138,6 +138,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()
@ -185,6 +186,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -198,6 +200,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -1003,7 +1003,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -56,7 +56,6 @@ class CodebookIndexExtractor:
"""
def __init__(self, params: AttributeDict):
self.params = params
params.subsets = ["clean-100"]
if self.params.full_libri:

View File

@ -111,7 +111,7 @@ def batch_force_alignment(
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
for t, batch_size in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]

View File

@ -329,6 +329,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -1132,7 +1132,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -117,7 +117,7 @@ class BatchedOptimizer(Optimizer):
yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
parameters_names=None,
show_dominant_parameters=True,
):
assert parameters_names is not None, (
"Please prepare parameters_names,"
"which is a List[List[str]]. Each List[str] is for a group"
@ -224,9 +223,7 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
@ -325,7 +322,7 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples:
for p, state, param_names in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
@ -410,7 +407,7 @@ class ScaledAdam(BatchedOptimizer):
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
for p, state, batch_param_names in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
@ -426,7 +423,6 @@ class ScaledAdam(BatchedOptimizer):
for name, sumsq_orig, rms, grad in zip(
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
@ -1039,7 +1035,7 @@ def _test_scaled_adam(hidden_dim: int):
# if epoch == 130:
# opts = diagnostics.TensorDiagnosticOptions(
# 2 ** 22
# 512
# ) # allow 4 megabytes per sub-module
# diagnostic = diagnostics.attach_diagnostics(m, opts)

View File

@ -74,6 +74,7 @@ def test_conv2d_subsampling():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()
@ -128,6 +129,7 @@ def test_rel_pos():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()
@ -204,6 +206,7 @@ def test_zipformer_encoder_layer():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()
@ -284,6 +287,7 @@ def test_zipformer_encoder():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()
@ -338,6 +342,7 @@ def test_zipformer():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
input_nodes = session.get_inputs()

View File

@ -1028,7 +1028,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1052,7 +1052,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -326,41 +326,49 @@ def main():
encoder = ort.InferenceSession(
args.encoder_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
decoder = ort.InferenceSession(
args.decoder_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
joiner = ort.InferenceSession(
args.joiner_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
lconv = ort.InferenceSession(
args.lconv_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
frame_reducer = ort.InferenceSession(
args.frame_reducer_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
ctc_output = ort.InferenceSession(
args.ctc_output_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)
sp = spm.SentencePieceProcessor()

View File

@ -1042,7 +1042,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -413,6 +413,7 @@ def export_decoder_model_onnx(
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -401,6 +401,7 @@ def export_decoder_model_onnx(
context_size = decoder_model.decoder.context_size
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -130,6 +130,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()
@ -229,6 +230,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -242,6 +244,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -1029,7 +1029,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1030,7 +1030,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -865,7 +865,7 @@ class ZipformerEncoderLayer(nn.Module):
return final_dropout_rate
else:
return initial_dropout_rate - (
initial_dropout_rate * final_dropout_rate
initial_dropout_rate - final_dropout_rate
) * (self.batch_count / warmup_period)
def forward(

View File

@ -1141,7 +1141,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1154,7 +1154,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -230,7 +230,9 @@ class Conformer(Transformer):
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
else:
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
x = self.encoder(
x, pos_emb, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
if self.normalize_before:
x = self.after_norm(x)

View File

@ -543,7 +543,6 @@ def train_one_epoch(
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train

View File

@ -463,7 +463,6 @@ def train_one_epoch(
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train

View File

@ -513,7 +513,6 @@ def train_one_epoch(
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train

View File

@ -517,7 +517,6 @@ def train_one_epoch(
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train

View File

@ -61,10 +61,15 @@ class Decoder(nn.Module):
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(decoder_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.5, max_abs=1.0,
prob=0.05)
self.balancer = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
self.blank_id = blank_id
@ -81,10 +86,15 @@ class Decoder(nn.Module):
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
min_positive=0.0, max_positive=1.0,
min_abs=0.5, max_abs=1.0,
prob=0.05)
self.balancer2 = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
@ -107,9 +117,7 @@ class Decoder(nn.Module):
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output

View File

@ -506,6 +506,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -353,6 +353,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -52,12 +52,13 @@ class Joiner(nn.Module):
Returns:
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
assert encoder_out.ndim == decoder_out.ndim, (
encoder_out.shape,
decoder_out.shape,
)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
decoder_out
)
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
logit = encoder_out + decoder_out

View File

@ -303,7 +303,9 @@ def main():
for test_set, test_dl in zip(test_sets, test_dl):
start_time = time.time()
results, total_duration = decode_dataset(dl=test_dl, model=model, token_table=token_table)
results, total_duration = decode_dataset(
dl=test_dl, model=model, token_table=token_table
)
end_time = time.time()
elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / total_duration

View File

@ -146,6 +146,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()
@ -236,6 +237,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -249,6 +251,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -151,12 +151,14 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -170,6 +172,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -116,7 +116,7 @@ class BatchedOptimizer(Optimizer):
yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
@ -181,7 +181,6 @@ class ScaledAdam(BatchedOptimizer):
size_update_period=4,
clipping_update_period=100,
):
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@ -299,8 +298,8 @@ class ScaledAdam(BatchedOptimizer):
# the input is groups of parameter or named parameter.
for cur_group in iterable_or_groups:
assert "named_params" in cur_group
name_list = [ x[0] for x in cur_group["named_params"] ]
p_list = [ x[1] for x in cur_group["named_params"] ]
name_list = [x[0] for x in cur_group["named_params"]]
p_list = [x[1] for x in cur_group["named_params"]]
del cur_group["named_params"]
cur_group["params"] = p_list
param_groups.append(cur_group)
@ -327,9 +326,7 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
@ -428,7 +425,7 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples:
for p, state, param_names in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
@ -513,7 +510,7 @@ class ScaledAdam(BatchedOptimizer):
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
for p, state, batch_param_names in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
@ -529,7 +526,6 @@ class ScaledAdam(BatchedOptimizer):
for name, sumsq_orig, rms, grad in zip(
batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
@ -667,8 +663,7 @@ class ScaledAdam(BatchedOptimizer):
# We have to look at the trained model for parameters at or around the
# param_max_rms, because sometimes they can indicate a problem with the
# topology or settings.
scale_step = torch.minimum(scale_step,
(param_max_rms - param_rms) / param_rms)
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
@ -879,7 +874,8 @@ class Eden(LRScheduler):
warmup_factor = (
1.0
if self.batch >= self.warmup_batches
else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
else self.warmup_start
+ (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
# else 0.5 + 0.5 * (self.batch / self.warmup_batches)
)
@ -1111,7 +1107,7 @@ def _test_scaled_adam(hidden_dim: int):
# if epoch == 130:
# opts = diagnostics.TensorDiagnosticOptions(
# 2 ** 22
# 512
# ) # allow 4 megabytes per sub-module
# diagnostic = diagnostics.attach_diagnostics(m, opts)

View File

@ -100,17 +100,13 @@ class Model(nn.Module):
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(
self, feature: Tensor, feature_lens: Tensor
) -> Tuple[Tensor, Tensor]:
def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
logits = self.encoder_proj(encoder_out)
@ -168,9 +164,7 @@ def main():
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

File diff suppressed because it is too large Load Diff

View File

@ -282,9 +282,7 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat(
[state_list[i][-1] for i in range(batch_size)], dim=0
)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
batch_states.append(processed_lens)
return batch_states
@ -322,9 +320,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(
chunks=batch_size, dim=1
)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
@ -355,9 +351,7 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(
chunks=batch_size, dim=0
)
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
@ -380,11 +374,7 @@ def streaming_forward(
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(
x,
x_lens,
new_cached_embed_left_pad,
) = model.encoder_embed.streaming_forward(
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
@ -404,9 +394,7 @@ def streaming_forward(
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
@ -494,9 +482,7 @@ def decode_one_chunk(
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(
model=model, encoder_out=encoder_out, streams=decode_streams
)
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens
@ -517,9 +503,7 @@ def decode_one_chunk(
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
states = unstack_states(new_states)
@ -577,9 +561,7 @@ def decode_dataset(
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = get_init_states(
model=model, batch_size=1, device=device
)
initial_states = get_init_states(model=model, batch_size=1, device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
@ -649,9 +631,7 @@ def decode_dataset(
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
return {key: decode_results}
@ -684,8 +664,7 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -718,9 +697,7 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
@ -760,9 +737,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -789,9 +766,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"

View File

@ -107,9 +107,7 @@ class ConvNeXt(nn.Module):
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand(
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
)
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
> layerdrop_rate
)
else:
@ -278,9 +276,7 @@ class Conv2dSubsampling(nn.Module):
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat(
(0.0, 4.0), (20000.0, 8.0), default=4.0
),
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
prob=(0.025, 0.25),
grad_scale=0.02,
)
@ -331,7 +327,7 @@ class Conv2dSubsampling(nn.Module):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item() , (x.size(1), x_lens.max())
assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
return x, x_lens
@ -403,8 +399,8 @@ class Conv2dSubsampling(nn.Module):
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(
batch_size, channels, left_pad, freq
).to(device)
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
device
)
return cached_embed_left_pad

View File

@ -604,11 +604,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module:
assert (
params.use_transducer or params.use_ctc
), (f"At least one of them should be True, "
assert params.use_transducer or params.use_ctc, (
f"At least one of them should be True, "
f"but got params.use_transducer={params.use_transducer}, "
f"params.use_ctc={params.use_ctc}")
f"params.use_ctc={params.use_ctc}"
)
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
@ -808,17 +808,16 @@ def compute_loss(
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s if batch_idx_train >= warm_step
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0 if batch_idx_train >= warm_step
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += (
simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
@ -1166,7 +1165,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -981,7 +981,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -746,7 +746,6 @@ def train_one_epoch(
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@ -966,7 +965,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -1019,7 +1018,6 @@ def run(rank, world_size, args):
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs + 1):
scheduler.step_epoch(epoch - 1)
fix_random_seed(params.seed + epoch - 1)
train_dl.sampler.set_epoch(epoch - 1)
@ -1118,7 +1116,6 @@ def scan_pessimistic_batches_for_oom(
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _, _ = compute_loss(
params=params,
model=model,

View File

@ -1164,7 +1164,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -915,7 +915,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -69,7 +69,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer
from icefall import diagnostics, byte_encode, tokenize_by_CJK_char
from icefall import byte_encode, diagnostics, tokenize_by_CJK_char
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -1018,7 +1018,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -905,7 +905,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1126,7 +1126,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -315,6 +315,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -886,7 +886,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -258,6 +258,7 @@ def main():
encoder_session = ort.InferenceSession(
args.onnx_encoder_filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
test_encoder(model, encoder_session)
@ -265,6 +266,7 @@ def main():
decoder_session = ort.InferenceSession(
args.onnx_decoder_filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
test_decoder(model, decoder_session)
@ -272,14 +274,17 @@ def main():
joiner_session = ort.InferenceSession(
args.onnx_joiner_filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
joiner_encoder_proj_session = ort.InferenceSession(
args.onnx_joiner_encoder_proj_filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
joiner_decoder_proj_session = ort.InferenceSession(
args.onnx_joiner_decoder_proj_filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)
test_joiner(
model,

View File

@ -851,7 +851,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -404,6 +404,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -335,6 +335,7 @@ def export_decoder_model_onnx(
vocab_size = decoder_model.decoder.vocab_size
y = torch.zeros(10, context_size, dtype=torch.int64)
decoder_model = torch.jit.script(decoder_model)
torch.onnx.export(
decoder_model,
y,

View File

@ -139,6 +139,7 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()
@ -186,6 +187,7 @@ class OnnxModel:
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -199,6 +201,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -158,12 +158,14 @@ class OnnxModel:
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
@ -177,6 +179,7 @@ class OnnxModel:
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
joiner_meta = self.joiner.get_modelmeta().custom_metadata_map

View File

@ -985,7 +985,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1128,7 +1128,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -1001,7 +1001,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

View File

@ -993,7 +993,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
512
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)

Some files were not shown because too many files have changed in this diff Show More