mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Merge remote-tracking branch 'k2-fsa/master'
This commit is contained in:
commit
ecfb3e9c26
@ -177,8 +177,8 @@ def post_processing(
|
||||
) -> List[Tuple[List[str], List[str]]]:
|
||||
new_results = []
|
||||
for ref, hyp in results:
|
||||
new_ref = asr_text_post_processing(" ".join(ref))
|
||||
new_hyp = asr_text_post_processing(" ".join(hyp))
|
||||
new_ref = asr_text_post_processing(" ".join(ref)).split()
|
||||
new_hyp = asr_text_post_processing(" ".join(hyp)).split()
|
||||
new_results.append((new_ref, new_hyp))
|
||||
return new_results
|
||||
|
||||
|
@ -276,7 +276,7 @@ def greedy_search(
|
||||
context_size = model.decoder.context_size
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
|
||||
device = model.device
|
||||
device = next(model.parameters()).device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size, device=device, dtype=torch.int64
|
||||
@ -350,7 +350,7 @@ def greedy_search_batch(
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
device = model.device
|
||||
device = next(model.parameters()).device
|
||||
|
||||
batch_size = encoder_out.size(0)
|
||||
T = encoder_out.size(1)
|
||||
@ -580,7 +580,7 @@ def modified_beam_search(
|
||||
blank_id = model.decoder.blank_id
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
device = next(model.parameters()).device
|
||||
B = [HypothesisList() for _ in range(batch_size)]
|
||||
for i in range(batch_size):
|
||||
B[i].add(
|
||||
@ -705,7 +705,7 @@ def _deprecated_modified_beam_search(
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
device = next(model.parameters()).device
|
||||
|
||||
T = encoder_out.size(1)
|
||||
|
||||
@ -813,7 +813,7 @@ def beam_search(
|
||||
unk_id = getattr(model, "unk_id", blank_id)
|
||||
context_size = model.decoder.context_size
|
||||
|
||||
device = model.device
|
||||
device = next(model.parameters()).device
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[blank_id] * context_size,
|
||||
|
@ -1001,7 +1001,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
except RuntimeError as e:
|
||||
except Exception as e:
|
||||
if "CUDA out of memory" in str(e):
|
||||
logging.error(
|
||||
"Your GPU ran out of memory with the current "
|
||||
@ -1010,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
f"Failing criterion: {criterion} "
|
||||
f"(={crit_values[criterion]}) ..."
|
||||
)
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
display_and_save_batch(batch, params=params, sp=sp)
|
||||
raise
|
||||
|
||||
|
||||
|
@ -250,7 +250,7 @@ def decode_one_batch(
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = model.device
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
@ -560,7 +560,6 @@ def main():
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
@ -125,8 +125,8 @@ def get_parser():
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
help="""Resume training from this epoch. It should be positive.
|
||||
If larger than 1, it will load checkpoint from
|
||||
exp-dir/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
@ -479,7 +479,7 @@ def load_checkpoint_if_available(
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
@ -529,7 +529,7 @@ def save_checkpoint(
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
is_training: bool,
|
||||
@ -553,7 +553,11 @@ def compute_loss(
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = model.device
|
||||
device = (
|
||||
model.device
|
||||
if isinstance(model, DDP)
|
||||
else next(model.parameters()).device
|
||||
)
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
@ -609,7 +613,7 @@ def compute_loss(
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
sp: spm.SentencePieceProcessor,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
@ -643,7 +647,7 @@ def compute_validation_loss(
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: LRSchedulerType,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
@ -857,6 +861,7 @@ def run(rank, world_size, args):
|
||||
# model_avg is only used with rank 0
|
||||
model_avg = copy.deepcopy(model)
|
||||
|
||||
assert params.start_epoch > 0, params.start_epoch
|
||||
checkpoints = load_checkpoint_if_available(
|
||||
params=params, model=model, model_avg=model_avg
|
||||
)
|
||||
@ -865,11 +870,6 @@ def run(rank, world_size, args):
|
||||
if world_size > 1:
|
||||
logging.info("Using DDP")
|
||||
model = DDP(model, device_ids=[rank])
|
||||
model.device = device
|
||||
|
||||
if rank == 0:
|
||||
model_avg.to(device)
|
||||
model_avg.device = device
|
||||
|
||||
optimizer = Eve(model.parameters(), lr=params.initial_lr)
|
||||
|
||||
@ -990,7 +990,7 @@ def run(rank, world_size, args):
|
||||
|
||||
|
||||
def scan_pessimistic_batches_for_oom(
|
||||
model: nn.Module,
|
||||
model: Union[nn.Module, DDP],
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
|
@ -467,5 +467,7 @@ def average_state_dict(
|
||||
uniqued_names = list(uniqued.values())
|
||||
for k in uniqued_names:
|
||||
state_dict_1[k] *= weight_1
|
||||
state_dict_1[k] += state_dict_2[k] * weight_2
|
||||
state_dict_1[k] += (
|
||||
state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
|
||||
)
|
||||
state_dict_1[k] *= scaling_factor
|
||||
|
Loading…
x
Reference in New Issue
Block a user