from local

This commit is contained in:
dohe0342 2022-12-29 11:32:11 +09:00
parent 6625b97615
commit 8cfc975314
7 changed files with 53 additions and 28 deletions

View File

@ -23,7 +23,7 @@
# To start from scratch, you can
# set stage=0, stop_stage=4, use_extracted_codebook=False
stage=2
stage=0
stop_stage=4
# Set the GPUs available.
@ -35,8 +35,7 @@ stop_stage=4
# export CUDA_VISIBLE_DEVICES="0"
#
# Suppose GPU 2,3,4,5 are available.
#export CUDA_VISIBLE_DEVICES="0,1,2,3"
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
# export CUDA_VISIBLE_DEVICES="0,1,2,3"
exp_dir=./pruned_transducer_stateless6/exp
mkdir -p $exp_dir
@ -50,7 +49,7 @@ full_libri=False
# "True" -> stage 0 and stage 1 would be skipped,
# and directly download the extracted codebook indexes for distillation
# "False" -> start from scratch
use_extracted_codebook=False
use_extracted_codebook=True
# teacher_model_id can be one of
# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
@ -156,8 +155,14 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
log "Downloading extracted codebook indexes to $codebook_download_dir"
# Make sure you have git-lfs installed (https://git-lfs.github.com)
# The codebook indexes are generated using lhotse 1.11.0, to avoid
# potential issues, we recommend you to use lhotse version >= 1.11.0
lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))")
if [ "$lhotse_version" == "False" ]; then
log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch."
fi
git lfs install
git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
mkdir -p data/vq_fbank
mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/

View File

@ -652,16 +652,16 @@ class ActivationBalancer(torch.nn.Module):
def forward(self, x: Tensor) -> Tensor:
if random.random() >= self.balance_prob:
return x
else:
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor / self.balance_prob,
self.min_abs,
self.max_abs,
)
return ActivationBalancerFunction.apply(
x,
self.channel_dim,
self.min_positive,
self.max_positive,
self.max_factor / self.balance_prob,
self.min_abs,
self.max_abs,
)
class DoubleSwishFunction(torch.autograd.Function):

View File

@ -282,7 +282,7 @@ def convert_scaled_to_non_scaled(
if not inplace:
model = copy.deepcopy(model)
excluded_patterns = r"self_attn\.(in|out)_proj"
excluded_patterns = r"(self|src)_attn\.(in|out)_proj"
p = re.compile(excluded_patterns)
d = {}

View File

@ -965,7 +965,6 @@ def run(rank, world_size, args):
logging.info("About to create model")
model = get_transducer_model(params)
logging.info(model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

View File

@ -929,7 +929,6 @@ def run(rank, world_size, args):
logging.info("About to create model")
model = get_transducer_model(params)
logging.info(model)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -986,11 +985,9 @@ def run(rank, world_size, args):
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
'''
logging.warning(
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
)
'''
return False
# In pruned RNN-T, we require that T >= S
@ -1001,9 +998,8 @@ def run(rank, world_size, args):
# for subsampling
T = ((c.num_frames - 1) // 2 - 1) // 2
tokens = sp.encode(c.supervisions[0].text, out_type=str)
if T < len(tokens):
'''
logging.warning(
f"Exclude cut with ID {c.id} from training. "
f"Number of frames (before subsampling): {c.num_frames}. "
@ -1012,9 +1008,8 @@ def run(rank, world_size, args):
f"Tokens: {tokens}. "
f"Number of tokens: {len(tokens)}"
)
'''
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)

View File

@ -244,10 +244,36 @@ class CodebookIndexExtractor:
)
cuts_vq = load_manifest(vq_manifest_path)
cuts_ori = load_manifest(ori_manifest_path)
cuts_vq = cuts_vq.sort_like(cuts_ori)
for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
assert cut_vq.id == cut_ori.id
cut_ori.codebook_indexes = cut_vq.codebook_indexes
assert len(cuts_vq) == len(cuts_ori), "Cuts should have the same length!"
if set(cuts_vq.ids) == set(cuts_ori.ids):
# IDs match exactly
cuts_vq = cuts_vq.sort_like(cuts_ori)
for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
assert cut_vq.id == cut_ori.id, (cut_vq.id, cut_ori.id)
cut_ori.codebook_indexes = cut_vq.codebook_indexes
else:
# in case of ID mismatch, remap them
# get the mapping between audio and cut ID
logging
ori_id_map = {}
for id in cuts_ori.ids:
# some text normalization
if "sp" in id:
clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
else:
clean_id = "-".join(id.split("-")[:3])
ori_id_map[clean_id] = id
for id in cuts_vq.ids:
if "sp" in id:
clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
else:
clean_id = "-".join(id.split("-")[:3])
assert clean_id in ori_id_map, clean_id
cuts_ori[ori_id_map[clean_id]].codebook_indexes = cuts_vq[
id
].codebook_indexes
CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
logging.info(f"Processed {subset}.")