Repository: https://github.com/k2-fsa/icefall.git (mirror)
Commit: 8cfc975314 ("from local")
Parent: 6625b97615
Binary file not shown.
@@ -23,7 +23,7 @@
 # To start from scratch, you can
 # set stage=0, stop_stage=4, use_extracted_codebook=False
 
-stage=2
+stage=0
 stop_stage=4
 
 # Set the GPUs available.
@@ -35,8 +35,7 @@ stop_stage=4
 # export CUDA_VISIBLE_DEVICES="0"
 #
 # Suppose GPU 2,3,4,5 are available.
-#export CUDA_VISIBLE_DEVICES="0,1,2,3"
-export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+# export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
 exp_dir=./pruned_transducer_stateless6/exp
 mkdir -p $exp_dir
@@ -50,7 +49,7 @@ full_libri=False
 # "True" -> stage 0 and stage 1 would be skipped,
 # and directly download the extracted codebook indexes for distillation
 # "False" -> start from scratch
-use_extracted_codebook=False
+use_extracted_codebook=True
 
 # teacher_model_id can be one of
 # "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use.
@@ -156,8 +155,14 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   fi
   log "Downloading extracted codebook indexes to $codebook_download_dir"
   # Make sure you have git-lfs installed (https://git-lfs.github.com)
+  # The codebook indexes are generated using lhotse 1.11.0, to avoid
+  # potential issues, we recommend you to use lhotse version >= 1.11.0
+  lhotse_version=$(python3 -c "import lhotse; from packaging import version; print(version.parse(lhotse.version.__version__)>=version.parse('1.11.0'))")
+  if [ "$lhotse_version" == "False" ]; then
+    log "Expecting lhotse >= 1.11.0. This may lead to potential ID mismatch."
+  fi
   git lfs install
-  git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
+  git clone https://huggingface.co/marcoyang/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir
 
   mkdir -p data/vq_fbank
   mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/
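Note: the embedded Python one-liner above only checks the installed lhotse version against 1.11.0 and prints "True" or "False" for the shell test that follows. A minimal standalone sketch of the same check (assuming the packaging package is installed):

import lhotse
from packaging import version

# Prints True when the installed lhotse is at least 1.11.0, else False;
# the shell script then compares this printed string against "False".
print(version.parse(lhotse.version.__version__) >= version.parse("1.11.0"))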
@@ -652,16 +652,16 @@ class ActivationBalancer(torch.nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         if random.random() >= self.balance_prob:
             return x
-        else:
-            return ActivationBalancerFunction.apply(
-                x,
-                self.channel_dim,
-                self.min_positive,
-                self.max_positive,
-                self.max_factor / self.balance_prob,
-                self.min_abs,
-                self.max_abs,
-            )
+
+        return ActivationBalancerFunction.apply(
+            x,
+            self.channel_dim,
+            self.min_positive,
+            self.max_positive,
+            self.max_factor / self.balance_prob,
+            self.min_abs,
+            self.max_abs,
+        )
 
 
 class DoubleSwishFunction(torch.autograd.Function):
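Note: passing self.max_factor / self.balance_prob compensates for the random gating above: the balancer fires only with probability balance_prob, so each application is scaled up to keep the expected correction roughly constant. A minimal sketch of that idea (standalone toy function, not the icefall API):

import random

def randomly_gated(x: float, correction: float, balance_prob: float = 0.25) -> float:
    # Apply the correction only balance_prob of the time, scaled by 1 / balance_prob,
    # so the expected value of the applied correction stays unchanged.
    if random.random() >= balance_prob:
        return x
    return x + correction / balance_prob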
@@ -282,7 +282,7 @@ def convert_scaled_to_non_scaled(
     if not inplace:
         model = copy.deepcopy(model)
 
-    excluded_patterns = r"self_attn\.(in|out)_proj"
+    excluded_patterns = r"(self|src)_attn\.(in|out)_proj"
     p = re.compile(excluded_patterns)
 
     d = {}
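Note: the widened pattern also excludes the source-attention ("src_attn") projections, not just self-attention. A quick check with illustrative module names (the names below are made up, not taken from the model):

import re

p = re.compile(r"(self|src)_attn\.(in|out)_proj")
print(bool(p.search("encoder.layers.0.self_attn.in_proj")))     # True
print(bool(p.search("decoder.layers.0.src_attn.out_proj")))     # True, newly excluded
print(bool(p.search("encoder.layers.0.feed_forward.linear1")))  # False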
@@ -965,7 +965,6 @@ def run(rank, world_size, args):
 
     logging.info("About to create model")
     model = get_transducer_model(params)
-    logging.info(model)
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -929,7 +929,6 @@ def run(rank, world_size, args):
 
     logging.info("About to create model")
     model = get_transducer_model(params)
-    logging.info(model)
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
@@ -986,11 +985,9 @@ def run(rank, world_size, args):
         # an utterance duration distribution for your dataset to select
         # the threshold
         if c.duration < 1.0 or c.duration > 20.0:
-            '''
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
             )
-            '''
             return False
 
         # In pruned RNN-T, we require that T >= S
@@ -1001,9 +998,8 @@ def run(rank, world_size, args):
         # for subsampling
         T = ((c.num_frames - 1) // 2 - 1) // 2
         tokens = sp.encode(c.supervisions[0].text, out_type=str)
 
         if T < len(tokens):
-            '''
             logging.warning(
                 f"Exclude cut with ID {c.id} from training. "
                 f"Number of frames (before subsampling): {c.num_frames}. "
@@ -1012,9 +1008,8 @@ def run(rank, world_size, args):
                 f"Tokens: {tokens}. "
                 f"Number of tokens: {len(tokens)}"
             )
-            '''
             return False
 
         return True
 
     train_cuts = train_cuts.filter(remove_short_and_long_utt)
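Note: the filter above drops cuts whose encoder output would be shorter than the BPE token sequence. The frontend subsamples by roughly a factor of 4, hence T = ((num_frames - 1) // 2 - 1) // 2. A worked example with made-up numbers:

num_frames = 500                      # about 5 seconds of 10 ms frames
T = ((num_frames - 1) // 2 - 1) // 2  # 124 frames after subsampling
tokens_len = 130                      # hypothetical BPE token count
print(T < tokens_len)                 # True -> this cut would be excluded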
@@ -244,10 +244,36 @@ class CodebookIndexExtractor:
         )
         cuts_vq = load_manifest(vq_manifest_path)
         cuts_ori = load_manifest(ori_manifest_path)
-        cuts_vq = cuts_vq.sort_like(cuts_ori)
-        for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
-            assert cut_vq.id == cut_ori.id
-            cut_ori.codebook_indexes = cut_vq.codebook_indexes
+        assert len(cuts_vq) == len(cuts_ori), "Cuts should have the same length!"
+
+        if set(cuts_vq.ids) == set(cuts_ori.ids):
+            # IDs match exactly
+            cuts_vq = cuts_vq.sort_like(cuts_ori)
+            for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)):
+                assert cut_vq.id == cut_ori.id, (cut_vq.id, cut_ori.id)
+                cut_ori.codebook_indexes = cut_vq.codebook_indexes
+        else:
+            # in case of ID mismatch, remap them
+            # get the mapping between audio and cut ID
+            logging.warning("Cut ID mismatch; remapping codebook indexes by cleaned IDs.")
+            ori_id_map = {}
+            for id in cuts_ori.ids:
+                # some text normalization
+                if "sp" in id:
+                    clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
+                else:
+                    clean_id = "-".join(id.split("-")[:3])
+                ori_id_map[clean_id] = id
+
+            for id in cuts_vq.ids:
+                if "sp" in id:
+                    clean_id = "-".join(id.split("-")[:3]) + "_" + id.split("_")[-1]
+                else:
+                    clean_id = "-".join(id.split("-")[:3])
+                assert clean_id in ori_id_map, clean_id
+                cuts_ori[ori_id_map[clean_id]].codebook_indexes = cuts_vq[
+                    id
+                ].codebook_indexes
 
         CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path)
         logging.info(f"Processed {subset}.")
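Note: the remapping branch keys both manifests by a "cleaned" cut ID: the first three dash-separated fields, plus the trailing "_sp..." suffix for speed-perturbed cuts. A small illustration (the cut IDs below are hypothetical):

def clean_cut_id(cut_id: str) -> str:
    # Same normalization as in the code above.
    if "sp" in cut_id:
        return "-".join(cut_id.split("-")[:3]) + "_" + cut_id.split("_")[-1]
    return "-".join(cut_id.split("-")[:3])

print(clean_cut_id("103-1240-0000-123"))        # 103-1240-0000
print(clean_cut_id("103-1240-0000-123_sp1.1"))  # 103-1240-0000_sp1.1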