Initialize the decode-and-adapt method

j-pong 2023-06-13 10:08:55 +09:00
parent 058c72b7d2
commit 3445e63962
15 changed files with 1202 additions and 33 deletions

Binary file not shown.

Binary file not shown.

View File

@@ -13,6 +13,39 @@ from lhotse import (
validate_recordings_and_supervisions,
)
from lhotse.utils import Pathlike, safe_extract
from tqdm.auto import tqdm
def tqdm_urlretrieve_hook(t):
"""Wraps tqdm instance.
Don't forget to close() or __exit__()
the tqdm instance once you're done with it (easiest using `with` syntax).
Example
-------
>>> from urllib.request import urlretrieve
>>> with tqdm(...) as t:
... reporthook = tqdm_urlretrieve_hook(t)
... urlretrieve(..., reporthook=reporthook)
Source: https://github.com/tqdm/tqdm/blob/master/examples/tqdm_wget.py
"""
last_b = [0]
def update_to(b=1, bsize=1, tsize=None):
"""
b : int, optional
Number of blocks transferred so far [default: 1].
bsize : int, optional
Size of each block (in tqdm units) [default: 1].
tsize : int, optional
Total size (in tqdm units). If [default: None] or -1,
remains unchanged.
"""
if tsize not in (None, -1):
t.total = tsize
displayed = t.update((b - last_b[0]) * bsize)
last_b[0] = b
return displayed
return update_to
def urlretrieve_progress(url, filename=None, data=None, desc=None):
"""

View File

@@ -228,14 +228,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Generate pseudo label"
for spk in {0..10}; do
spk_id=${spk#*$dest\/}
echo $spk_id
./pseudo.sh $spk_id $subset
done
fi
# if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
# log "Stage 10: Generate pseudo label"
# for spk in {0..10}; do
# spk_id=${spk#*$dest\/}
# echo $spk_id
# ./pseudo.sh $spk_id $subset
# done
# fi
: <<'END'

View File

@@ -39,6 +39,8 @@ from torch.utils.data import DataLoader
from icefall.utils import str2bool
from sampling import SingleUttSampler
class TedLiumAsrDataModule:
"""
@@ -355,11 +357,20 @@ class TedLiumAsrDataModule:
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts_test,
max_duration=self.args.max_duration,
shuffle=False,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
test_sampler = DynamicBucketingSampler(
cuts_test,
max_duration=self.args.max_duration,
shuffle=False,
)
else:
logging.info("Using SingleUttSampler.")
test_sampler = SingleUttSampler(
cuts_test,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
@@ -390,5 +401,5 @@ class TedLiumAsrDataModule:
@lru_cache()
def user_test_cuts(self, spk_id) -> CutSet:
logging.info("About to get test cuts")
logging.info(f"About to get test cuts : {spk_id}")
return load_manifest_lazy(self.args.manifest_dir / f"tedlium_cuts_test_{spk_id}.jsonl.gz")
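Taken together, this hunk lets a decoding script opt out of bucketing and iterate per-speaker test cuts one utterance at a time. A minimal usage sketch (argument handling is assumed to go through the module's existing `add_arguments`; the speaker id 3 is only an example):

import argparse

from asr_datamodule import TedLiumAsrDataModule

parser = argparse.ArgumentParser()
TedLiumAsrDataModule.add_arguments(parser)
args = parser.parse_args([])            # defaults are enough for this sketch
args.return_cuts = True                 # downstream code reads cut ids from the batch
args.bucketing_sampler = False          # fall back to SingleUttSampler

tedlium = TedLiumAsrDataModule(args)
test_cuts = tedlium.user_test_cuts(spk_id=3)    # loads tedlium_cuts_test_3.jsonl.gz
test_dl = tedlium.test_dataloaders(test_cuts)   # one utterance per batch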

View File

@@ -79,26 +79,26 @@ from icefall.utils import (
)
import fairseq
from data2vec_audio import LoRAModule
# from data2vec_audio import LoRAModule
LOG_EPS = math.log(1e-10)
class LoRAHook():
def __init__(self, module):
self.hook = module.register_forward_hook(self.hook_fn)
self.lora = LoRAModule(
embedding_dim=768,
rank=6,
lora_alpha=10000,
)
def hook_fn(self, module, input, output):
lora_out = self.lora(input[0])
output += lora_out
# class LoRAHook():
# def __init__(self, module):
# self.hook = module.register_forward_hook(self.hook_fn)
# self.lora = LoRAModule(
# embedding_dim=768,
# rank=6,
# lora_alpha=10000,
# )
# def hook_fn(self, module, input, output):
# lora_out = self.lora(input[0])
# output += lora_out
def save_checkpoint(self, i, iter_, save_dir):
if isinstance(self.lora, DDP):
lora = self.lora.module
torch.save(lora.state_dict(), f"{save_dir}/lora_{iter_}_{i}.pt")
# def save_checkpoint(self, i, iter_, save_dir):
# if isinstance(self.lora, DDP):
# lora = self.lora.module
# torch.save(lora.state_dict(), f"{save_dir}/lora_{iter_}_{i}.pt")
def get_parser():
@@ -304,6 +304,21 @@ def get_parser():
return parser
from typing import Any, Dict, Optional, Tuple, Union
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
setup_logger,
str2bool,
save_args,
)
import warnings
def decode_one_batch(
params: AttributeDict,
@@ -479,6 +494,8 @@ def decode_one_batch(
)
hyps.append(sp.decode(hyp).split())
print(hyps)
exit()
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
@@ -536,7 +553,7 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
log_interval = 1
else:
log_interval = 20
@@ -821,10 +838,12 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
args.bucketing_sampler = False
tedlium = TedLiumAsrDataModule(args)
#valid_cuts = tedlium.dev_cuts()
test_cuts = tedlium.user_test_cuts(spk_id=params.spk_id)
#valid_dl = tedlium.test_dataloaders(valid_cuts)
test_cuts = tedlium.user_test_cuts(spk_id=params.spk_id)
test_dl = tedlium.test_dataloaders(test_cuts)
#test_sets = ['dev', 'test']
@@ -844,7 +863,7 @@ def main():
save_results(
params=params,
test_set_name=test_set,
test_set_name=test_set + str(params.spk_id),
results_dict=results_dict,
)

View File

@@ -0,0 +1,882 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import TedLiumAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train_tta import add_model_arguments, add_rep_arguments, get_params, get_transducer_model
#from prompt_tuning import add_model_arguments, add_rep_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
import fairseq
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--model-name",
type=str,
default="",
help="""It specifies the model file name to use for decoding.""",
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=64,
help="""Used only when --decoding-method is
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--res-name",
type=str,
)
add_model_arguments(parser)
add_rep_arguments(parser)
return parser
from typing import Any, Dict, Optional, Tuple, Union
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
setup_logger,
str2bool,
save_args,
)
import warnings
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals the
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 2 or feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
#feature_lens = supervisions["num_frames"].to(device)
if feature.ndim == 2:
feature_lens = []
for supervision in supervisions['cut']:
try: feature_lens.append(supervision.tracks[0].cut.recording.num_samples)
except (AttributeError, IndexError): feature_lens.append(supervision.recording.num_samples)
feature_lens = torch.tensor(feature_lens)
elif feature.ndim == 3:
feature_lens = supervisions["num_frames"].to(device)
if params.simulate_streaming:
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens, prompt=model.prompt)
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
def decode_and_adapt(
params: AttributeDict,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
num_iter: int
) -> Tuple[Tensor, MetricsTracker]:
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 2 or feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
if feature.ndim == 2:
feature_lens = []
for supervision in supervisions['cut']:
try: feature_lens.append(supervision.tracks[0].cut.recording.num_samples)
except (AttributeError, IndexError): feature_lens.append(supervision.recording.num_samples)
feature_lens = torch.tensor(feature_lens)
elif feature.ndim == 3:
feature_lens = supervisions["num_frames"].to(device)
batch_idx_train = params.batch_idx_train
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
texts = [text.upper() for text in texts]
token_ids = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(token_ids).to(device)
model.train()
for i in range(num_iter):
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_output = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss_scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.ctc_loss_scale > 0:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
supervision_segments, token_ids = encode_supervisions(
supervisions,
subsampling_factor=params.subsampling_factor,
token_ids=token_ids,
)
# Works with a BPE model
decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device)
dense_fsa_vec = k2.DenseFsaVec(
ctc_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction="sum",
use_double_scores=params.use_double_scores,
)
assert ctc_loss.requires_grad == is_training
loss += params.ctc_loss_scale * ctc_loss
# self.adapted_model_losses.append(loss.item())
# self.adapted_models.append(self.copy_model_and_optimizer(self.models[0]))
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_output = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss_scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
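Note that the adaptation loop above calls `self.optimizer`, but `decode_and_adapt` is a module-level function with no `self`, so the optimizer has to be created and passed in by the caller. A minimal sketch of one possible wiring, assuming only the trainable adaptation parameters (e.g. a prompt or adapter) should be updated; the helper name and hyperparameters are illustrative, not part of this commit:

import torch

def make_adaptation_optimizer(model: torch.nn.Module, lr: float = 1e-4) -> torch.optim.Optimizer:
    # Update only parameters that still require grad (e.g. prompt/adapter weights),
    # keeping the pretrained encoder/decoder frozen during test-time adaptation.
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.Adam(trainable, lr=lr)

# optimizer = make_adaptation_optimizer(model)
# ...and inside decode_and_adapt the three `self.optimizer` calls would become:
# optimizer.zero_grad(); loss.backward(); optimizer.step()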
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 1
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
model.eval()
texts = batch["supervisions"]["text"]
texts = [text.upper() for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
# replace the supervisions with pseudo labels
batch["supervisions"]["text"] = [" ".join(hyp) for hyp in hyps_dict[params.decoding_method]]
# augment the single utterance (augmentation is executed automatically in the d2v model)
batch["inputs"] = batch["inputs"].repeat(4, 1)
decode_and_adapt(params, model, sp, batch, is_training=True, num_iter=10)
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
spk = None
wer = None
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
spk = str(test_set_name)
wer = str(val)
logging.info(s)
with open(f'./{params.res_name}.txt', 'a') as f:
f.write(f"{spk} {wer}\n")
@torch.no_grad()
def main():
parser = get_parser()
TedLiumAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
load_checkpoint(f"{params.exp_dir}/{params.model_name}", model)
model.to(device)
model.eval()
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
args.bucketing_sampler = False
tedlium = TedLiumAsrDataModule(args)
#valid_cuts = tedlium.dev_cuts()
#valid_dl = tedlium.test_dataloaders(valid_cuts)
test_cuts = tedlium.user_test_cuts(spk_id=params.spk_id)
test_dl = tedlium.test_dataloaders(test_cuts)
#test_sets = ['dev', 'test']
#test_dl = [valid_dl, test_dl]
test_sets = ['test']
test_dl = [test_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set + str(params.spk_id),
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,224 @@
import warnings
from typing import Any, Dict, Optional
from lhotse import CutSet, Seconds
from lhotse.dataset.sampling.base import CutSampler, TimeConstraint
from lhotse.dataset.sampling.data_source import DataSource
class SingleUttSampler(CutSampler):
"""
Samples cuts from a CutSet to satisfy the input constraints.
It behaves like an iterable that yields mini-batches of cuts (``CutSet`` objects).
The ``max_duration`` and ``max_cuts`` constraints are tracked as in ``SimpleCutSampler``,
but each yielded batch contains exactly one cut.
Padding required to collate the batch does not contribute to max frames/samples/duration.
Example usage::
>>> dataset = K2SpeechRecognitionDataset(cuts)
>>> sampler = SingleUttSampler(cuts, shuffle=True)
>>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
>>> for epoch in range(start_epoch, n_epochs):
... sampler.set_epoch(epoch)
... train(loader)
"""
def __init__(
self,
cuts: CutSet,
max_duration: Seconds = None,
max_cuts: Optional[int] = None,
shuffle: bool = False,
drop_last: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
seed: int = 0,
):
"""
SingleUttSampler's constructor.
:param cuts: the ``CutSet`` to sample data from.
:param max_duration: The maximum total recording duration from ``cuts``.
:param max_cuts: The maximum number of cuts sampled to form a mini-batch.
By default, this constraint is off.
:param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
`for epoch in range(10): for batch in dataset: ...` as every epoch will see a
different cuts order.
:param drop_last: When ``True``, the last batch is dropped if it's incomplete.
:param world_size: Total number of distributed nodes. We will try to infer it by default.
:param rank: Index of distributed node. We will try to infer it by default.
:param seed: Random seed used to consistently shuffle the dataset across different processes.
"""
super().__init__(
drop_last=drop_last,
shuffle=shuffle,
world_size=world_size,
rank=rank,
seed=seed,
)
self.data_source = DataSource(cuts)
self.time_constraint = TimeConstraint(
max_duration=max_duration,
max_cuts=max_cuts,
)
@property
def remaining_duration(self) -> Optional[float]:
"""
Remaining duration of data left in the sampler (may be inexact due to float arithmetic).
Not available when the CutSet is read in lazy mode (returns None).
"""
return self.data_source.remaining_duration
@property
def remaining_cuts(self) -> Optional[int]:
"""
Remaining number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
return self.data_source.remaining_cuts
@property
def num_cuts(self) -> Optional[int]:
"""
Total number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
if self.data_source.is_lazy:
return None
return len(self.data_source)
def state_dict(self) -> Dict[str, Any]:
"""
Return the current state of the sampler in a state_dict.
Together with ``load_state_dict()``, this can be used to restore the
training loop's state to the one stored in the state_dict.
"""
state_dict = super().state_dict()
state_dict.update(
{
"time_constraint": self.time_constraint.state_dict(),
}
)
return state_dict
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Restore the state of the sampler that is described in a state_dict.
This will result in the sampler yielding batches from where the previous training left it off.
.. caution::
The samplers are expected to be initialized with the same CutSets,
but this is not explicitly checked anywhere.
.. caution::
The input ``state_dict`` is being mutated: we remove each consumed key, and expect
it to be empty at the end of loading. If you don't want this behavior, pass a copy
inside of this function (e.g., using ``copy.deepcopy``).
.. note::
For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state`` has to be
handled in ``__iter__`` to make it avoid resetting the just-restored state (only once).
"""
time_constraint = TimeConstraint(**state_dict.pop("time_constraint"))
if self.time_constraint != time_constraint:
warnings.warn(
"SimpleCutSampler.load_state_dict(): Inconsistent time_constraint:\n"
f"expected {self.time_constraint}\n"
f"received {time_constraint}\n"
f"We will overwrite the settings with the received state_dict."
)
self.time_constraint = time_constraint
super().load_state_dict(state_dict)
# Restore the data source's state
if self.shuffle:
self.data_source.shuffle(self.seed + self.epoch)
self.data_source.fast_forward(self.diagnostics.current_epoch_stats.total_cuts)
def __iter__(self) -> "SingleUttSampler":
"""
Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
"""
# Restored state with load_state_dict()? Skip resetting only this once.
if self._just_restored_state:
return self
# Why reset the current epoch?
# Either we are iterating the epoch for the first time and it's a no-op,
# or we are iterating the same epoch again, in which case setting more steps
# than are actually available per epoch would have broken the checkpoint restoration.
self.diagnostics.reset_current_epoch()
# Reset the state to the beginning of the epoch.
if self.shuffle:
self.data_source.shuffle(self.seed + self.epoch)
iter(self.data_source)
return self
def _next_batch(self) -> CutSet:
# Keep iterating the underlying CutSet as long as we hit or exceed the constraints
# provided by user (the max number of frames or max number of cuts).
# Note: no actual data is loaded into memory yet because the manifests contain all the metadata
# required to do this operation.
self.time_constraint.reset()
cuts = []
while True:
# Check that we have not reached the end of the dataset.
try:
# If this doesn't raise (typical case), it's not the end: keep processing.
next_cut = next(self.data_source)
except StopIteration:
# No more cuts to sample from: if we have a partial batch,
# we may output it, unless the user requested to drop it.
# We also check if the batch is "almost there" to override drop_last.
if cuts and (
not self.drop_last or self.time_constraint.close_to_exceeding()
):
# We have a partial batch and we can return it.
return CutSet.from_cuts(cuts)
else:
# There is nothing more to return or it's discarded:
# signal the iteration code to stop.
self.diagnostics.discard(cuts)
raise StopIteration()
# Check whether the cut we're about to sample satisfies optional user-requested predicate.
if not self._filter_fn(next_cut):
# No - try another one.
self.diagnostics.discard_single(next_cut)
continue
# Track the duration/frames/etc. constraints.
self.time_constraint.add(next_cut)
cuts.append(next_cut)
break
# # Did we exceed the max_frames and max_cuts constraints?
# if not self.time_constraint.exceeded():
# # No - add the next cut to the batch, and keep trying.
# cuts.append(next_cut)
# else:
# # Yes. Do we have at least one cut in the batch?
# if cuts:
# # Yes. Return the batch, but keep the currently drawn cut for later.
# self.data_source.take_back(next_cut)
# break
# else:
# # No. We'll warn the user that the constrains might be too tight,
# # and return the cut anyway.
# warnings.warn(
# "The first cut drawn in batch collection violates "
# "the max_frames, max_cuts, or max_duration constraints - "
# "we'll return it anyway. "
# "Consider increasing max_frames/max_cuts/max_duration."
# )
# cuts.append(next_cut)
return CutSet.from_cuts(cuts)
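A short usage sketch for the sampler above, mirroring the class docstring (the manifest path and the max_duration value are placeholders):

from lhotse import CutSet

from sampling import SingleUttSampler

cuts = CutSet.from_file("data/fbank/tedlium_cuts_test_3.jsonl.gz")  # placeholder manifest

sampler = SingleUttSampler(cuts, max_duration=600.0, shuffle=False)
for cut_batch in sampler:
    # Each yielded CutSet holds exactly one cut, because _next_batch() breaks
    # right after the first accepted cut.
    assert len(cut_batch) == 1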