Merge remote-tracking branch 'k2-fsa/master'

yaozengwei 2022-05-07 11:07:48 +08:00
commit ecfb3e9c26
6 changed files with 26 additions and 25 deletions

View File

@@ -177,8 +177,8 @@ def post_processing(
 ) -> List[Tuple[List[str], List[str]]]:
     new_results = []
     for ref, hyp in results:
-        new_ref = asr_text_post_processing(" ".join(ref))
-        new_hyp = asr_text_post_processing(" ".join(hyp))
+        new_ref = asr_text_post_processing(" ".join(ref)).split()
+        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
         new_results.append((new_ref, new_hyp))
     return new_results
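
Note on the hunk above: asr_text_post_processing maps a string to a string, while the declared return type (and downstream WER scoring) expects ref and hyp as lists of words; without the trailing .split(), each entry would be one joined string. A minimal sketch, with a stand-in normalizer (the real asr_text_post_processing is not shown in this diff):

# Minimal sketch of the fixed post_processing; the normalizer body
# below is illustrative only.
from typing import List, Tuple


def asr_text_post_processing(text: str) -> str:
    return text.upper()  # stand-in normalization


def post_processing(
    results: List[Tuple[List[str], List[str]]],
) -> List[Tuple[List[str], List[str]]]:
    new_results = []
    for ref, hyp in results:
        # .split() turns the normalized string back into a word list,
        # matching the declared return type.
        new_ref = asr_text_post_processing(" ".join(ref)).split()
        new_hyp = asr_text_post_processing(" ".join(hyp)).split()
        new_results.append((new_ref, new_hyp))
    return new_results


assert post_processing([(["a", "b"], ["a", "c"])]) == [(["A", "B"], ["A", "C"])]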

View File

@@ -276,7 +276,7 @@ def greedy_search(
     context_size = model.decoder.context_size
     unk_id = getattr(model, "unk_id", blank_id)
-    device = model.device
+    device = next(model.parameters()).device
     decoder_input = torch.tensor(
         [blank_id] * context_size, device=device, dtype=torch.int64
@@ -350,7 +350,7 @@ def greedy_search_batch(
     assert encoder_out.ndim == 3
     assert encoder_out.size(0) >= 1, encoder_out.size(0)
-    device = model.device
+    device = next(model.parameters()).device
     batch_size = encoder_out.size(0)
     T = encoder_out.size(1)
@@ -580,7 +580,7 @@ def modified_beam_search(
     blank_id = model.decoder.blank_id
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     B = [HypothesisList() for _ in range(batch_size)]
     for i in range(batch_size):
         B[i].add(
@@ -705,7 +705,7 @@ def _deprecated_modified_beam_search(
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     T = encoder_out.size(1)
@@ -813,7 +813,7 @@ def beam_search(
     unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device
+    device = next(model.parameters()).device
     decoder_input = torch.tensor(
         [blank_id] * context_size,
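
All five hunks in this file make the same substitution: instead of reading a model.device attribute that callers had to attach by hand, the device is derived from the model's own parameters. This works for any nn.Module with at least one parameter. A minimal sketch:

# Minimal sketch: derive a module's device from its parameters instead
# of relying on a hand-attached `model.device` attribute.
import torch.nn as nn

model = nn.Linear(4, 2)
device = next(model.parameters()).device
print(device)  # cpu; prints cuda:0 after model.to("cuda")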

View File

@@ -1001,7 +1001,7 @@ def scan_pessimistic_batches_for_oom(
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
-        except RuntimeError as e:
+        except Exception as e:
             if "CUDA out of memory" in str(e):
                 logging.error(
                     "Your GPU ran out of memory with the current "
@@ -1010,7 +1010,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
                 display_and_save_batch(batch, params=params, sp=sp)
             raise
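
The except clause is widened from RuntimeError to Exception so that out-of-memory failures surfacing under other exception types are still reported before the re-raise. A minimal sketch of the pattern; run_batch and save_batch are hypothetical stand-ins for the real training step and batch dumper:

# Minimal sketch of the OOM-scan pattern.
import logging


def scan_batches_for_oom(batches, run_batch, save_batch):
    for criterion, batch in batches.items():
        try:
            run_batch(batch)
        except Exception as e:  # broader than RuntimeError on purpose
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory; "
                    f"failing criterion: {criterion}"
                )
                save_batch(batch)
            raise  # always re-raise so the caller can abort early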

View File

@@ -250,7 +250,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict.
     """
-    device = model.device
+    device = next(model.parameters()).device
     feature = batch["inputs"]
     assert feature.ndim == 3
@@ -560,7 +560,6 @@ def main():
     model.to(device)
     model.eval()
-    model.device = device
     if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

View File

@@ -125,8 +125,8 @@ def get_parser():
         "--start-epoch",
         type=int,
         default=1,
-        help="""Resume training from from this epoch.
-        If it is positive, it will load checkpoint from
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load checkpoint from
         exp-dir/epoch-{start_epoch-1}.pt
         """,
     )
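
The reworded help text pins down the semantics: --start-epoch must be positive; 1 means train from scratch, and N > 1 first loads exp-dir/epoch-{N-1}.pt. A sketch of the implied checkpoint selection (the function name is hypothetical; the epoch-{start_epoch-1}.pt pattern comes from the help string itself):

# Sketch of the resume logic implied by the help text.
from pathlib import Path
from typing import Optional


def checkpoint_to_load(exp_dir: Path, start_epoch: int) -> Optional[Path]:
    assert start_epoch > 0, start_epoch  # mirrors the new assertion in run()
    if start_epoch > 1:
        return exp_dir / f"epoch-{start_epoch - 1}.pt"
    return None  # start_epoch == 1: train from scratch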
@@ -479,7 +479,7 @@ def load_checkpoint_if_available(
 def save_checkpoint(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     model_avg: Optional[nn.Module] = None,
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
@@ -529,7 +529,7 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
@@ -553,7 +553,11 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = model.device
+    device = (
+        model.device
+        if isinstance(model, DDP)
+        else next(model.parameters()).device
+    )
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
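
The branch is needed because DistributedDataParallel exposes a device attribute while a bare nn.Module does not; probing the parameters covers the non-DDP case. The same logic in helper form, as a sketch:

# Minimal sketch of the device resolution now used in compute_loss.
from typing import Union

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def model_device(model: Union[nn.Module, DDP]) -> torch.device:
    # DDP carries a .device attribute; a plain nn.Module does not, so
    # fall back to the device of the first parameter.
    if isinstance(model, DDP):
        return model.device
    return next(model.parameters()).device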
@@ -609,7 +613,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     sp: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -643,7 +647,7 @@ def compute_validation_loss(
 def train_one_epoch(
     params: AttributeDict,
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
     sp: spm.SentencePieceProcessor,
@@ -857,6 +861,7 @@ def run(rank, world_size, args):
         # model_avg is only used with rank 0
         model_avg = copy.deepcopy(model)
+    assert params.start_epoch > 0, params.start_epoch
     checkpoints = load_checkpoint_if_available(
         params=params, model=model, model_avg=model_avg
     )
@@ -865,11 +870,6 @@ def run(rank, world_size, args):
     if world_size > 1:
         logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
-        model.device = device
-
-    if rank == 0:
-        model_avg.to(device)
-        model_avg.device = device
     optimizer = Eve(model.parameters(), lr=params.initial_lr)
@@ -990,7 +990,7 @@ def run(rank, world_size, args):
 def scan_pessimistic_batches_for_oom(
-    model: nn.Module,
+    model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
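
The widened Union[nn.Module, DDP] annotations throughout this file reflect that after model = DDP(model, device_ids=[rank]) the training loop passes the wrapper around. Where the underlying module is needed, e.g. when saving a checkpoint, the usual idiom is to unwrap via .module; a sketch (the common pattern, not necessarily this repo's exact code):

# Sketch of the usual DDP unwrapping idiom.
from typing import Union

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap(model: Union[nn.Module, DDP]) -> nn.Module:
    # DDP wraps the original module and exposes it as .module.
    return model.module if isinstance(model, DDP) else model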

View File

@@ -467,5 +467,7 @@ def average_state_dict(
     uniqued_names = list(uniqued.values())
     for k in uniqued_names:
         state_dict_1[k] *= weight_1
-        state_dict_1[k] += state_dict_2[k] * weight_2
+        state_dict_1[k] += (
+            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
+        )
         state_dict_1[k] *= scaling_factor
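
The new .to(device=state_dict_1[k].device) guards against the two state dicts living on different devices, e.g. a checkpoint loaded with map_location="cpu" being averaged into tensors already on GPU; without it the in-place += raises a device-mismatch RuntimeError. The core of the loop as a self-contained sketch (the real function also de-duplicates names before this loop):

# Self-contained sketch of device-safe weighted state-dict averaging.
import torch


def average_into(state_dict_1, state_dict_2, weight_1, weight_2,
                 scaling_factor=1.0):
    for k in state_dict_1:
        state_dict_1[k] *= weight_1
        # Move the other tensor onto this tensor's device before the
        # in-place add; mixing CPU and GPU tensors would raise.
        state_dict_1[k] += (
            state_dict_2[k].to(device=state_dict_1[k].device) * weight_2
        )
        state_dict_1[k] *= scaling_factor


sd1 = {"w": torch.ones(2)}
sd2 = {"w": torch.zeros(2)}
average_into(sd1, sd2, weight_1=0.5, weight_2=0.5)
print(sd1["w"])  # tensor([0.5000, 0.5000])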