mirror of https://github.com/k2-fsa/icefall.git
commit 6bc387cf46 (parent 91432397cf)

check fairseq and quantization
@@ -26,12 +26,23 @@ if [ $stage -eq 0 ]; then
   # https://github.com/pytorch/fairseq
   # when testing this code:
   # commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
   #
+  has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
+  if [ $has_fairseq == 'False' ]; then
+    echo "Please install fairseq before running the following stages"
+    exit 1
+  fi
+
   # Install quantization toolkit:
   # pip install git+https://github.com/danpovey/quantization.git@master
   # when testing this code:
   # commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used.
+
+  has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
+  if [ $has_quantization == 'False' ]; then
+    echo "Please install quantization before running the following stages"
+    exit 1
+  fi
 
   echo "Download hubert model."
   # Parameters about model.
   exp_dir=./pruned_transducer_stateless6/exp/
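The two guards above shell out to Python and rely on importlib.util.find_spec, which returns None when a package cannot be imported. A minimal Python-only sketch of the same probe (the loop and the combined message are illustrative, not part of the script):

    # Probe for required packages the same way the shell one-liners do:
    # importlib.util.find_spec returns None for a package that is not importable.
    import importlib.util

    for pkg in ("fairseq", "quantization"):
        if importlib.util.find_spec(pkg) is None:
            raise SystemExit(f"Please install {pkg} before running the following stages")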
@@ -34,6 +34,15 @@ from icefall.utils import AttributeDict
 
 
 def _load_hubert_model(params: AttributeDict):
+    """
+    Load the hubert model.
+
+    The model loaded is specified by params.hubert_model_dir
+    and params.teacher_model_id.
+
+    The returned model carries hubert, while the processor is responsible
+    for mapping the model's output to human-readable transcripts.
+    """
     cfg_task = OmegaConf.create(
         {
             "_name": "hubert_pretraining",
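For context, fine-tuned HuBERT checkpoints from fairseq are commonly loaded through checkpoint_utils. A hedged sketch follows; the checkpoint filename is hypothetical, and in the recipe the real path is derived from params.hubert_model_dir and params.teacher_model_id:

    # Hedged sketch: one common way to load a fine-tuned HuBERT checkpoint.
    # The filename below is hypothetical.
    from fairseq import checkpoint_utils

    models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_xtralarge_ll60k_finetune_ls960.pt"]  # hypothetical checkpoint path
    )
    model = models[0].eval()  # the teacher is used for inference only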
@@ -130,7 +139,7 @@ class HubertXlargeFineTuned:
     def extract_layers_result(
         self,
         batch: Dict,
-    ) -> Dict[str, torch.Tensor]:
+    ) -> List[torch.Tensor]:
         """
         Extract activations from all layers.
         """
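The signature change reflects that per-layer activations come back as an ordered list rather than a dict, so a layer is chosen by position instead of by string key. A self-contained sketch of consuming that list (the shapes and the 1-based layer convention are assumptions):

    import torch
    from typing import List

    def select_layer(layer_results: List[torch.Tensor], embedding_layer: int) -> torch.Tensor:
        # Assumes embedding_layer is 1-based, as layer ids usually are in configs.
        return layer_results[embedding_layer - 1]

    # HuBERT X-Large has 48 transformer layers with 1280-dim activations.
    fake_results = [torch.randn(4, 100, 1280) for _ in range(48)]  # (B, T, C) per layer
    print(select_layer(fake_results, 36).shape)  # torch.Size([4, 100, 1280])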
@@ -154,7 +163,6 @@ class HubertXlargeFineTuned:
         features = features.transpose(1, 2)
         features = self.w2v_model.layer_norm(features)
 
-        if padding_mask is not None:
         padding_mask = self.w2v_model.forward_padding_mask(
             features, padding_mask
         )
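Dropping the None-check means forward_padding_mask now runs unconditionally, so callers are expected to always supply a mask. A hedged sketch of the mask convention (True marks padded positions, as fairseq models expect), independent of the model code:

    import torch

    # Build a padding mask for a padded batch: True marks padded positions.
    lengths = torch.tensor([16000, 12000, 8000])  # valid samples per utterance
    padding_mask = torch.arange(int(lengths.max())) >= lengths.unsqueeze(1)
    print(padding_mask.shape)  # torch.Size([3, 16000])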
@@ -169,6 +177,16 @@ class HubertXlargeFineTuned:
         return layer_results
 
     def extract_embedding(self, batch) -> Tuple[torch.tensor, List[int]]:
+        """
+        Extract embeddings specified by self.params.embedding_layer.
+
+        These embeddings could be used to train a quantizer
+        or to extract codebook indexes.
+
+        The returned List[int] gives the valid length of each embedding.
+        We only want to store codebook indexes related to
+        these valid embeddings.
+        """
         supervisions = batch["supervisions"]
         cut_list = supervisions["cut"]
         assert all(c.start == 0 for c in cut_list)
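Since utterances in a batch are padded to a common length, codebook indexes derived from the padded embeddings must be cut back to each utterance's true length before being stored, which is why the valid lengths are returned alongside the embeddings. A hedged sketch of that trimming step (names and shapes are illustrative):

    import torch
    from typing import List

    def trim_to_valid(codebook_indexes: torch.Tensor, num_frames: List[int]) -> List[torch.Tensor]:
        # codebook_indexes: (B, T, num_codebooks); num_frames: valid frames per utterance.
        return [codebook_indexes[i, :n] for i, n in enumerate(num_frames)]

    fake_indexes = torch.randint(0, 256, (3, 500, 8))  # batch padded to 500 frames
    for trimmed in trim_to_valid(fake_indexes, [500, 430, 310]):
        print(trimmed.shape)  # (500, 8), (430, 8), (310, 8)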