Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

Commit ecfb3e9c26: Merge remote-tracking branch 'k2-fsa/master'
@@ -177,8 +177,8 @@ def post_processing(
 ) -> List[Tuple[List[str], List[str]]]:
     new_results = []
     for ref, hyp in results:
-        new_ref = asr_text_post_processing(" ".join(ref))
-        new_hyp = asr_text_post_processing(" ".join(hyp))
+        new_ref = asr_text_post_processing(" ".join(ref)).split()
+        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
         new_results.append((new_ref, new_hyp))
     return new_results
 
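Note: the added .split() matters because the WER scoring downstream expects lists of words, not joined strings. A minimal self-contained sketch of the round trip (the stand-in normalizer below is hypothetical; the real asr_text_post_processing lives in the recipe):

from typing import List, Tuple


def asr_text_post_processing(text: str) -> str:
    # Hypothetical stand-in for the recipe's normalizer; illustration only.
    return text.upper()


def post_processing(
    results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[List[str], List[str]]]:
    new_results = []
    for ref, hyp in results:
        # Join to one string for normalization, then split back into a
        # word list so scoring compares tokens rather than whole strings.
        new_ref = asr_text_post_processing(" ".join(ref)).split()
        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
        new_results.append((new_ref, new_hyp))
    return new_results


print(post_processing([(["hello", "world"], ["hello", "word"])]))
# [(['HELLO', 'WORLD'], ['HELLO', 'WORD'])]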
@@ -276,7 +276,7 @@ def greedy_search(
     context_size = model.decoder.context_size
     unk_id = getattr(model, "unk_id", blank_id)
 
-    device = model.device
+    device = next(model.parameters()).device
 
     decoder_input = torch.tensor(
         [blank_id] * context_size, device=device, dtype=torch.int64
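Note: this hunk and the four that follow (greedy_search_batch, modified_beam_search, _deprecated_modified_beam_search, beam_search) make the same one-line change. A plain nn.Module has no .device attribute unless one is manually attached, so the robust idiom is to read the device off any parameter. A minimal sketch:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# Works for any module with at least one parameter, with no need to
# stash an ad-hoc .device attribute on the model.
device = next(model.parameters()).device
print(device)  # cpu, or cuda:0 after model.to("cuda:0")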
@@ -350,7 +350,7 @@ def greedy_search_batch(
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
 
-    device = model.device
+    device = next(model.parameters()).device
 
     batch_size = encoder_out.size(0)
     T = encoder_out.size(1)
@@ -580,7 +580,7 @@ def modified_beam_search(
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     B = [HypothesisList() for _ in range(batch_size)]
     for i in range(batch_size):
         B[i].add(
@@ -705,7 +705,7 @@ def _deprecated_modified_beam_search(
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
 
-    device = model.device
+    device = next(model.parameters()).device
 
     T = encoder_out.size(1)
 
@@ -813,7 +813,7 @@ def beam_search(
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
 
-    device = model.device
+    device = next(model.parameters()).device
 
     decoder_input = torch.tensor(
         [blank_id] * context_size,
@@ -1001,7 +1001,7 @@ def scan_pessimistic_batches_for_oom(
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
-        except RuntimeError as e:
+        except Exception as e:
             if "CUDA out of memory" in str(e):
                 logging.error(
                     "Your GPU ran out of memory with the current "
@@ -1010,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
                 display_and_save_batch(batch, params=params, sp=sp)
             raise
 
 
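Note: a CUDA OOM arrives as a RuntimeError whose message contains "CUDA out of memory"; catching the broader Exception also covers other failure types while still re-raising after logging. A sketch of the pattern with a hypothetical one-step run function:

import logging

import torch


def try_one_step(model: torch.nn.Module, batch: torch.Tensor, optimizer) -> None:
    # Hypothetical helper: run a single optimization step and turn a
    # CUDA OOM into an actionable log message before re-raising.
    try:
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    except Exception as e:
        if "CUDA out of memory" in str(e):
            logging.error(
                "Your GPU ran out of memory with the current settings; "
                "try reducing the batch size or max_duration."
            )
        raise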
@@ -250,7 +250,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict.
     """
-    device = model.device
+    device = next(model.parameters()).device
     feature = batch["inputs"]
     assert feature.ndim == 3
 
@@ -560,7 +560,6 @@ def main():
 
     model.to(device)
     model.eval()
-    model.device = device
 
     if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
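Note: dropping model.device = device removes an attribute that can silently go stale once the model moves again. A small illustration of the hazard:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model.device = torch.device("cpu")  # the old ad-hoc pattern

if torch.cuda.is_available():
    model.to("cuda:0")
    print(model.device)                     # still cpu -- stale
    print(next(model.parameters()).device)  # cuda:0 -- correct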
@@ -125,8 +125,8 @@ def get_parser():
         "--start-epoch",
         type=int,
         default=1,
-        help="""Resume training from from this epoch.
-        If it is positive, it will load checkpoint from
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
         exp-dir/epoch-{start_epoch-1}.pt
         """,
     )
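Note: with default=1, epoch numbering starts at 1 and the resume rule reads directly off the docstring. A sketch of the implied lookup (checkpoint_to_load is a hypothetical helper, not part of the recipe):

from pathlib import Path
from typing import Optional


def checkpoint_to_load(exp_dir: Path, start_epoch: int) -> Optional[Path]:
    # start_epoch must be positive; 1 means train from scratch, while
    # anything larger resumes from the previous epoch's checkpoint.
    assert start_epoch > 0, start_epoch
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None


print(checkpoint_to_load(Path("exp"), 3))  # exp/epoch-2.pt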
@@ -479,7 +479,7 @@ def load_checkpoint_if_available(
 
 def save_checkpoint(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     model_avg: Optional[nn.Module] = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
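Note: once the model may be wrapped in DistributedDataParallel, annotations widen to Union[nn.Module, DDP], and code that needs the raw network (e.g. for saving a state_dict) unwraps via .module. A minimal sketch:

from typing import Union

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap(model: Union[nn.Module, DDP]) -> nn.Module:
    # DDP keeps the wrapped network in .module; a bare module is
    # returned unchanged.
    return model.module if isinstance(model, DDP) else model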
@@ -529,7 +529,7 @@ def save_checkpoint(
 
 def compute_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
@@ -553,7 +553,11 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
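Note: the branch exists because DDP exposes a .device attribute while a bare nn.Module does not. The same logic as a standalone sketch:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def model_device(model) -> torch.device:
    # DDP carries .device; for a plain module, fall back to the device
    # of the first parameter.
    if isinstance(model, DDP):
        return model.device
    return next(model.parameters()).device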
@@ -609,7 +613,7 @@ def compute_loss(
 
 def compute_validation_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -643,7 +647,7 @@ def compute_validation_loss(
 
 def train_one_epoch(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
@@ -857,6 +861,7 @@ def run(rank, world_size, args):
     # model_avg is only used with rank 0
     model_avg = copy.deepcopy(model)
 
+    assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
         params=params, model=model, model_avg=model_avg
     )
@@ -865,11 +870,6 @@ def run(rank, world_size, args):
     if world_size > 1:
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
-        model.device = device
-
-    if rank == 0:
-        model_avg.to(device)
-        model_avg.device = device
 
     optimizer = Eve(model.parameters(), lr=params.initial_lr)
 
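Note: the deleted lines stop attaching .device to the DDP wrapper and to model_avg; both can be derived when needed. The surviving setup order still matters: the model must be on its GPU before DDP wraps it. A sketch of that order under the usual one-process-per-GPU assumption (wrap_for_ddp is hypothetical):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(model: torch.nn.Module, rank: int, world_size: int):
    # Hypothetical helper: move the model to its GPU first, then wrap.
    # No ad-hoc .device attribute is attached; callers derive the device
    # from the parameters or from DDP's own .device.
    device = torch.device("cuda", rank)
    model.to(device)
    if world_size > 1:
        model = DDP(model, device_ids=[rank])
    return model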
@@ -990,7 +990,7 @@ def run(rank, world_size, args):
 
 
 def scan_pessimistic_batches_for_oom(
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
@@ -467,5 +467,7 @@ def average_state_dict(
     uniqued_names = list(uniqued.values())
     for k in uniqued_names:
         state_dict_1[k] *= weight_1
-        state_dict_1[k] += state_dict_2[k] * weight_2
+        state_dict_1[k] += (
+            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+        )
         state_dict_1[k] *= scaling_factor
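Note: the .to(device=...) guards against the two state dicts living on different devices (e.g. one loaded with map_location="cpu"), which would otherwise make the in-place += raise. A simplified self-contained sketch of the weighted average (the real function also de-duplicates parameter names, omitted here):

from typing import Dict

import torch


def average_state_dict(
    state_dict_1: Dict[str, torch.Tensor],
    state_dict_2: Dict[str, torch.Tensor],
    weight_1: float,
    weight_2: float,
    scaling_factor: float = 1.0,
) -> None:
    # In-place weighted sum; tensors from state_dict_2 are moved onto
    # state_dict_1's device before the addition.
    for k in state_dict_1:
        state_dict_1[k] *= weight_1
        state_dict_1[k] += (
            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
        )
        state_dict_1[k] *= scaling_factor


a = {"w": torch.ones(2)}
b = {"w": torch.zeros(2)}
average_state_dict(a, b, 0.5, 0.5)
print(a["w"])  # tensor([0.5000, 0.5000])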
|
Loading…
x
Reference in New Issue
Block a user