check fairseq and quantization

Guo Liyong 2022-05-27 12:26:43 +08:00
parent 91432397cf
commit 6bc387cf46
2 changed files with 35 additions and 6 deletions

@@ -26,12 +26,23 @@ if [ $stage -eq 0 ]; then
# https://github.com/pytorch/fairseq
# when testing this code:
# commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used.
#
has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)")
if [ $has_fairseq == 'False' ]; then
echo "Please install fairseq before running following stages"
exit 1
fi
# Install quantization toolkit:
# pip install git+https://github.com/danpovey/quantization.git@master
# when testing this code:
# commit c17ffe67aa2e6ca6b6855c50fde812f2eed7870b is used.
has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)")
if [ $has_quantization == 'False' ]; then
echo "Please install quantization before running following stages"
exit 1
fi
echo "Download hubert model."
# Parameters about model.
exp_dir=./pruned_transducer_stateless6/exp/
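
For reference, the shell check above just wraps importlib.util.find_spec; below is a minimal Python sketch of the same idea (check_dependencies is a hypothetical helper, not part of this commit):

import importlib.util
import sys


def check_dependencies(required):
    """Exit with an error message if any required package is missing."""
    for name in required:
        if importlib.util.find_spec(name) is None:
            print(f"Please install {name} before running the following stages")
            sys.exit(1)


check_dependencies(["fairseq", "quantization"])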

@@ -34,6 +34,15 @@ from icefall.utils import AttributeDict
def _load_hubert_model(params: AttributeDict):
"""
Load the hubert model.
The model loaded is specified by params.hubert_model_dir
and params.teacher_model_id.
The returned model carries the hubert model,
while the processor is responsible for mapping the model's output to human-readable transcripts.
"""
cfg_task = OmegaConf.create(
{
"_name": "hubert_pretraining",
@@ -130,7 +139,7 @@ class HubertXlargeFineTuned:
def extract_layers_result(
self,
batch: Dict,
) -> Dict[str, torch.Tensor]:
) -> List[torch.Tensor]:
"""
Extract activations from all layers.
"""
@@ -154,10 +163,9 @@ class HubertXlargeFineTuned:
features = features.transpose(1, 2)
features = self.w2v_model.layer_norm(features)
if padding_mask is not None:
padding_mask = self.w2v_model.forward_padding_mask(
features, padding_mask
)
padding_mask = self.w2v_model.forward_padding_mask(
features, padding_mask
)
if self.w2v_model.post_extract_proj is not None:
features = self.w2v_model.post_extract_proj(features)
@@ -169,6 +177,16 @@
return layer_results
def extract_embedding(self, batch) -> Tuple[torch.Tensor, List[int]]:
"""
Extract embeddings specified by self.params.embedding_layer.
These embeddings could be used to train the quantizer
or to extract codebook indexes.
The returned List[int] contains the valid length of each embedding.
We only want to store codebook indexes related to
these valid embeddings.
"""
supervisions = batch["supervisions"]
cut_list = supervisions["cut"]
assert all(c.start == 0 for c in cut_list)
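
A hedged sketch of consuming extract_embedding's output, assuming the embeddings come back as a padded (N, T, C) tensor; teacher and batch are placeholders, not names from this commit:

embeddings, num_frames = teacher.extract_embedding(batch)
# Keep only the valid frames of each utterance; frames beyond
# num_frames[i] are padding and should not produce codebook indexes.
valid_embeddings = [
    embeddings[i, : num_frames[i]] for i in range(embeddings.size(0))
]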