mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-13 19:14:20 +00:00
check fairseq and quantization
This commit is contained in:
parent
91432397cf
commit
6bc387cf46
@ -26,12 +26,23 @@ if [ $stage -eq 0 ]; then
|
||||
# https://github.com/pytorch/fairseq
|
||||
# when testing this code:
|
||||
# commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
|
||||
#
|
||||
has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
|
||||
if [ $has_fairseq == 'False' ]; then
|
||||
echo "Please install fairseq before running following stages"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Install quantization toolkit:
|
||||
# pip install git+https://github.com/danpovey/quantization.git@master
|
||||
# when testing this code:
|
||||
# commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used.
|
||||
|
||||
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
|
||||
if [ $has_quantization == 'False' ]; then
|
||||
echo "Please install quantization before running following stages"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Download hubert model."
|
||||
# Parameters about model.
|
||||
exp_dir=./pruned_transducer_stateless6/exp/
|
||||
|
@ -34,6 +34,15 @@ from icefall.utils import AttributeDict
|
||||
|
||||
|
||||
def _load_hubert_model(params: AttributeDict):
|
||||
"""
|
||||
Load the hubert model.
|
||||
|
||||
The model loaded is specified by params.hubert_model_dir
|
||||
and params.teacher_model_id.
|
||||
|
||||
Returned model carries hubert,
|
||||
while processor is responsible to map model's output to human readable transcripts.
|
||||
"""
|
||||
cfg_task = OmegaConf.create(
|
||||
{
|
||||
"_name": "hubert_pretraining",
|
||||
@ -130,7 +139,7 @@ class HubertXlargeFineTuned:
|
||||
def extract_layers_result(
|
||||
self,
|
||||
batch: Dict,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Extract activations from all layers.
|
||||
"""
|
||||
@ -154,10 +163,9 @@ class HubertXlargeFineTuned:
|
||||
features = features.transpose(1, 2)
|
||||
features = self.w2v_model.layer_norm(features)
|
||||
|
||||
if padding_mask is not None:
|
||||
padding_mask = self.w2v_model.forward_padding_mask(
|
||||
features, padding_mask
|
||||
)
|
||||
padding_mask = self.w2v_model.forward_padding_mask(
|
||||
features, padding_mask
|
||||
)
|
||||
|
||||
if self.w2v_model.post_extract_proj is not None:
|
||||
features = self.w2v_model.post_extract_proj(features)
|
||||
@ -169,6 +177,16 @@ class HubertXlargeFineTuned:
|
||||
return layer_results
|
||||
|
||||
def extract_embedding(self, batch) -> Tuple[torch.tensor, List[int]]:
|
||||
"""
|
||||
Eextract embeddings specified by self.params.embedding_layer.
|
||||
|
||||
These embeddings could be used to train quantizer
|
||||
or to extract codebook indexes.
|
||||
|
||||
The returned List[int] is valid length of each embedding.
|
||||
We only want to store codebook indexes related to
|
||||
these valid embeddings.
|
||||
"""
|
||||
supervisions = batch["supervisions"]
|
||||
cut_list = supervisions["cut"]
|
||||
assert all(c.start == 0 for c in cut_list)
|
||||
|
Loading…
x
Reference in New Issue
Block a user