mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Merge branch 'k2-fsa:master' into fix/multi-zh-en
This commit is contained in:
commit
0157687925
8
.github/workflows/style_check.yml
vendored
8
.github/workflows/style_check.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0
|
python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 isort==5.10.1
|
||||||
# Click issue fixed in https://github.com/psf/black/pull/2966
|
# Click issue fixed in https://github.com/psf/black/pull/2966
|
||||||
|
|
||||||
- name: Run flake8
|
- name: Run flake8
|
||||||
@ -67,3 +67,9 @@ jobs:
|
|||||||
working-directory: ${{github.workspace}}
|
working-directory: ${{github.workspace}}
|
||||||
run: |
|
run: |
|
||||||
black --check --diff .
|
black --check --diff .
|
||||||
|
|
||||||
|
- name: Run isort
|
||||||
|
shell: bash
|
||||||
|
working-directory: ${{github.workspace}}
|
||||||
|
run: |
|
||||||
|
isort --check --diff .
|
||||||
|
@ -26,7 +26,7 @@ repos:
|
|||||||
# E121,E123,E126,E226,E24,E704,W503,W504
|
# E121,E123,E126,E226,E24,E704,W503,W504
|
||||||
|
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
rev: 5.11.5
|
rev: 5.10.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
args: ["--profile=black"]
|
args: ["--profile=black"]
|
||||||
|
@ -79,10 +79,10 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from icefall import is_module_available
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
import torch
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -70,9 +70,9 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from tokenizer import Tokenizer
|
from tokenizer import Tokenizer
|
||||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -23,6 +23,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from lhotse import CutSet, SupervisionSegment
|
from lhotse import CutSet, SupervisionSegment
|
||||||
from lhotse.recipes.utils import read_manifests_if_cached
|
from lhotse.recipes.utils import read_manifests_if_cached
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
# Similar text filtering and normalization procedure as in:
|
# Similar text filtering and normalization procedure as in:
|
||||||
|
@ -76,6 +76,7 @@ from beam_search import (
|
|||||||
)
|
)
|
||||||
from gigaspeech_scoring import asr_text_post_processing
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
|
@ -88,7 +88,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import GigaSpeechAsrDataModule
|
from asr_datamodule import GigaSpeechAsrDataModule
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -42,12 +42,10 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import GigaSpeechAsrDataModule
|
from asr_datamodule import GigaSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import keywords_search
|
||||||
keywords_search,
|
from lhotse.cut import Cut
|
||||||
)
|
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from lhotse.cut import Cut
|
|
||||||
from icefall import ContextGraph
|
from icefall import ContextGraph
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -76,6 +76,20 @@ from torch import Tensor
|
|||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from train import (
|
||||||
|
add_model_arguments,
|
||||||
|
add_training_arguments,
|
||||||
|
compute_loss,
|
||||||
|
compute_validation_loss,
|
||||||
|
display_and_save_batch,
|
||||||
|
get_adjusted_batch_count,
|
||||||
|
get_model,
|
||||||
|
get_params,
|
||||||
|
load_checkpoint_if_available,
|
||||||
|
save_checkpoint,
|
||||||
|
scan_pessimistic_batches_for_oom,
|
||||||
|
set_batch_count,
|
||||||
|
)
|
||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.checkpoint import remove_checkpoints
|
from icefall.checkpoint import remove_checkpoints
|
||||||
@ -95,21 +109,6 @@ from icefall.utils import (
|
|||||||
str2bool,
|
str2bool,
|
||||||
)
|
)
|
||||||
|
|
||||||
from train import (
|
|
||||||
add_model_arguments,
|
|
||||||
add_training_arguments,
|
|
||||||
compute_loss,
|
|
||||||
compute_validation_loss,
|
|
||||||
display_and_save_batch,
|
|
||||||
get_adjusted_batch_count,
|
|
||||||
get_model,
|
|
||||||
get_params,
|
|
||||||
load_checkpoint_if_available,
|
|
||||||
save_checkpoint,
|
|
||||||
scan_pessimistic_batches_for_oom,
|
|
||||||
set_batch_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,8 +24,7 @@ To run this file, do:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from train import get_ctc_model, get_params
|
||||||
from train import get_params, get_ctc_model
|
|
||||||
|
|
||||||
|
|
||||||
def test_model():
|
def test_model():
|
||||||
|
@ -59,9 +59,9 @@ import onnx
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
|
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
||||||
from emformer import Emformer
|
from emformer import Emformer
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from do_not_use_it_directly import add_model_arguments, get_params, get_transducer_model
|
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -39,7 +39,7 @@ Usage of this script:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
@ -47,7 +47,6 @@ import torch
|
|||||||
import torchaudio
|
import torchaudio
|
||||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from typing import Optional, List
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -31,28 +31,28 @@ https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stat
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import torch.multiprocessing as mp
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch.nn as nn
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
fast_beam_search_one_best,
|
fast_beam_search_one_best,
|
||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
|
||||||
from lhotse import CutSet, load_manifest_lazy
|
from lhotse import CutSet, load_manifest_lazy
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from lhotse.supervision import AlignmentItem
|
|
||||||
from lhotse.serialization import SequentialJsonlWriter
|
from lhotse.serialization import SequentialJsonlWriter
|
||||||
|
from lhotse.supervision import AlignmentItem
|
||||||
|
|
||||||
|
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -73,12 +73,11 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
from icefall import is_module_available
|
from icefall import is_module_available
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -22,11 +22,12 @@ Usage: ./pruned_transducer_stateless/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from train import add_model_arguments, get_encoder_model, get_params
|
||||||
|
|
||||||
from icefall.profiler import get_model_profile
|
from icefall.profiler import get_model_profile
|
||||||
from train import get_encoder_model, add_model_arguments, get_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -78,10 +78,10 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from icefall import is_module_available
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
import torch
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -76,8 +76,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import AsrDataModule
|
from asr_datamodule import AsrDataModule
|
||||||
from librispeech import LibriSpeech
|
from librispeech import LibriSpeech
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless4/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from scaling import BasicNorm, DoubleSwish
|
||||||
from typing import Tuple
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||||
|
|
||||||
from icefall.profiler import get_model_profile
|
from icefall.profiler import get_model_profile
|
||||||
from scaling import BasicNorm, DoubleSwish
|
|
||||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -82,8 +82,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
from beam_search import Hypothesis, HypothesisList, get_hyps_shape
|
||||||
|
|
||||||
# The force alignment problem can be formulated as finding
|
# The force alignment problem can be formulated as finding
|
||||||
|
@ -107,9 +107,6 @@ import k2
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# from asr_datamodule import LibriSpeechAsrDataModule
|
|
||||||
from gigaspeech import GigaSpeechAsrDataModule
|
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
fast_beam_search_nbest,
|
fast_beam_search_nbest,
|
||||||
@ -120,6 +117,9 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from gigaspeech import GigaSpeechAsrDataModule
|
||||||
from gigaspeech_scoring import asr_text_post_processing
|
from gigaspeech_scoring import asr_text_post_processing
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
@ -65,16 +65,15 @@ from typing import Dict, List
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
average_checkpoints_with_averaged_model,
|
average_checkpoints_with_averaged_model,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -22,15 +22,15 @@ Usage: ./pruned_transducer_stateless7/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
from scaling import BasicNorm, DoubleSwish
|
||||||
from typing import Tuple
|
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from train import add_model_arguments, get_encoder_model, get_joiner_model, get_params
|
||||||
|
|
||||||
from icefall.profiler import get_model_profile
|
from icefall.profiler import get_model_profile
|
||||||
from scaling import BasicNorm, DoubleSwish
|
|
||||||
from train import get_encoder_model, get_joiner_model, add_model_arguments, get_params
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -75,8 +75,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
@ -24,7 +24,6 @@ To run this file, do:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
@ -118,8 +118,8 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -18,10 +18,7 @@ from typing import List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling import (
|
from scaling import ActivationBalancer, ScaledConv1d
|
||||||
ActivationBalancer,
|
|
||||||
ScaledConv1d,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LConv(nn.Module):
|
class LConv(nn.Module):
|
||||||
|
@ -52,7 +52,7 @@ import onnxruntime as ort
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
|
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,7 +4,6 @@
|
|||||||
import ncnn
|
import ncnn
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
layer_list = []
|
layer_list = []
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,7 +42,6 @@ import ncnn
|
|||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature
|
||||||
|
|
||||||
from ncnn_custom_layer import RegisterCustomLayers
|
from ncnn_custom_layer import RegisterCustomLayers
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import pprint
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
import pprint
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
@ -88,7 +88,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -22,9 +22,9 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import add_sos, make_pad_mask
|
from icefall.utils import add_sos, make_pad_mask
|
||||||
from scaling import ScaledLinear
|
|
||||||
|
|
||||||
|
|
||||||
class AsrModel(nn.Module):
|
class AsrModel(nn.Module):
|
||||||
|
@ -22,24 +22,24 @@ Usage: ./zipformer/my_profile.py
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
from torch import Tensor, nn
|
|
||||||
|
|
||||||
from icefall.utils import make_pad_mask
|
|
||||||
from icefall.profiler import get_model_profile
|
|
||||||
from scaling import BiasNorm
|
from scaling import BiasNorm
|
||||||
|
from torch import Tensor, nn
|
||||||
from train import (
|
from train import (
|
||||||
|
add_model_arguments,
|
||||||
get_encoder_embed,
|
get_encoder_embed,
|
||||||
get_encoder_model,
|
get_encoder_model,
|
||||||
get_joiner_model,
|
get_joiner_model,
|
||||||
add_model_arguments,
|
|
||||||
get_params,
|
get_params,
|
||||||
)
|
)
|
||||||
from zipformer import BypassModule
|
from zipformer import BypassModule
|
||||||
|
|
||||||
|
from icefall.profiler import get_model_profile
|
||||||
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
@ -77,11 +77,10 @@ from typing import List, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from k2 import SymbolTable
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
from k2 import SymbolTable
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
from typing import Dict
|
|
||||||
import kaldifst
|
import kaldifst
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
from typing import Dict
|
|
||||||
import kaldifst
|
import kaldifst
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
@ -27,11 +27,10 @@ https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
from typing import Dict
|
|
||||||
import kaldifst
|
import kaldifst
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
@ -15,15 +15,16 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
|
||||||
import logging
|
import logging
|
||||||
import k2
|
|
||||||
from torch.cuda.amp import custom_fwd, custom_bwd
|
|
||||||
import random
|
|
||||||
import torch
|
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||||
|
|
||||||
|
|
||||||
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
|
def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
|
||||||
|
@ -51,7 +51,7 @@ from streaming_beam_search import (
|
|||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_model
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
@ -16,11 +16,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
|
||||||
from scaling import (
|
from scaling import (
|
||||||
Balancer,
|
Balancer,
|
||||||
BiasNorm,
|
BiasNorm,
|
||||||
@ -34,6 +33,7 @@ from scaling import (
|
|||||||
SwooshR,
|
SwooshR,
|
||||||
Whiten,
|
Whiten,
|
||||||
)
|
)
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
class ConvNeXt(nn.Module):
|
class ConvNeXt(nn.Module):
|
||||||
|
@ -858,7 +858,9 @@ def main():
|
|||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_model(params)
|
model = get_model(params)
|
||||||
import pdb; pdb.set_trace()
|
import pdb
|
||||||
|
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
if not params.use_averaged_model:
|
if not params.use_averaged_model:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
@ -877,9 +879,13 @@ def main():
|
|||||||
)
|
)
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device), strict=False)
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
elif params.avg == 1:
|
elif params.avg == 1:
|
||||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False)
|
load_checkpoint(
|
||||||
|
f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
start = params.epoch - params.avg + 1
|
start = params.epoch - params.avg + 1
|
||||||
filenames = []
|
filenames = []
|
||||||
@ -888,7 +894,9 @@ def main():
|
|||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device), strict=False)
|
model.load_state_dict(
|
||||||
|
average_checkpoints(filenames, device=device), strict=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
@ -917,7 +925,7 @@ def main():
|
|||||||
filename_end=filename_end,
|
filename_end=filename_end,
|
||||||
device=device,
|
device=device,
|
||||||
),
|
),
|
||||||
strict=False
|
strict=False,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert params.avg > 0, params.avg
|
assert params.avg > 0, params.avg
|
||||||
@ -936,7 +944,7 @@ def main():
|
|||||||
filename_end=filename_end,
|
filename_end=filename_end,
|
||||||
device=device,
|
device=device,
|
||||||
),
|
),
|
||||||
strict=False
|
strict=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
@ -121,7 +121,7 @@ from beam_search import (
|
|||||||
modified_beam_search_lm_shallow_fusion,
|
modified_beam_search_lm_shallow_fusion,
|
||||||
modified_beam_search_LODR,
|
modified_beam_search_LODR,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, add_finetune_arguments, get_model, get_params
|
from train import add_finetune_arguments, add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
from icefall import ContextGraph, LmScorer, NgramLm
|
from icefall import ContextGraph, LmScorer, NgramLm
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
|
@ -72,7 +72,7 @@ import torch.nn as nn
|
|||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, add_finetune_arguments, get_model, get_params
|
from train import add_finetune_arguments, add_model_arguments, get_model, get_params
|
||||||
from zipformer import Zipformer2
|
from zipformer import Zipformer2
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
|
@ -77,11 +77,10 @@ from typing import List, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from k2 import SymbolTable
|
||||||
from onnx_pretrained import greedy_search, OnnxModel
|
from onnx_pretrained import OnnxModel, greedy_search
|
||||||
|
|
||||||
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
from icefall.utils import setup_logger, store_transcripts, write_error_stats
|
||||||
from k2 import SymbolTable
|
|
||||||
|
|
||||||
conversational_filler = [
|
conversational_filler = [
|
||||||
"UH",
|
"UH",
|
||||||
@ -182,6 +181,7 @@ def get_parser():
|
|||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def post_processing(
|
def post_processing(
|
||||||
results: List[Tuple[str, List[str], List[str]]],
|
results: List[Tuple[str, List[str], List[str]]],
|
||||||
) -> List[Tuple[str, List[str], List[str]]]:
|
) -> List[Tuple[str, List[str], List[str]]]:
|
||||||
@ -192,6 +192,7 @@ def post_processing(
|
|||||||
new_results.append((key, new_ref, new_hyp))
|
new_results.append((key, new_ref, new_hyp))
|
||||||
return new_results
|
return new_results
|
||||||
|
|
||||||
|
|
||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
model: OnnxModel, token_table: SymbolTable, batch: dict
|
model: OnnxModel, token_table: SymbolTable, batch: dict
|
||||||
) -> List[List[str]]:
|
) -> List[List[str]]:
|
||||||
|
@ -137,14 +137,14 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
|
|||||||
"--use-adapters",
|
"--use-adapters",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
default=True,
|
default=True,
|
||||||
help="If use adapter to finetune the model"
|
help="If use adapter to finetune the model",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--adapter-dim",
|
"--adapter-dim",
|
||||||
type=int,
|
type=int,
|
||||||
default=16,
|
default=16,
|
||||||
help="The bottleneck dimension of the adapter"
|
help="The bottleneck dimension of the adapter",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -1273,7 +1273,11 @@ def run(rank, world_size, args):
|
|||||||
else:
|
else:
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
|
||||||
logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable/num_param * 100))
|
logging.info(
|
||||||
|
"A total of {} trainable parameters ({:.3f}% of the whole model)".format(
|
||||||
|
num_trainable, num_trainable / num_param * 100
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
|
@ -40,13 +40,13 @@ from scaling import (
|
|||||||
Dropout2,
|
Dropout2,
|
||||||
FloatLike,
|
FloatLike,
|
||||||
ScheduledFloat,
|
ScheduledFloat,
|
||||||
|
SwooshL,
|
||||||
|
SwooshR,
|
||||||
Whiten,
|
Whiten,
|
||||||
convert_num_channels,
|
convert_num_channels,
|
||||||
limit_param_value,
|
limit_param_value,
|
||||||
penalize_abs_values_gt,
|
penalize_abs_values_gt,
|
||||||
softmax,
|
softmax,
|
||||||
SwooshL,
|
|
||||||
SwooshR,
|
|
||||||
)
|
)
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials.
|
See https://k2-fsa.github.io/icefall/recipes/TTS/ljspeech/vits.html for detailed tutorials.
|
||||||
|
|
||||||
Training logs, Tensorboard logs, and checkpoints are uploaded to https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2023-11-29.
|
Training logs, Tensorboard logs, and checkpoints are uploaded to
|
||||||
|
https://huggingface.co/Zengwei/icefall-tts-ljspeech-vits-2024-02-28
|
||||||
|
@ -91,7 +91,7 @@ def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
|||||||
for key, value in meta_data.items():
|
for key, value in meta_data.items():
|
||||||
meta = model.metadata_props.add()
|
meta = model.metadata_props.add()
|
||||||
meta.key = key
|
meta.key = key
|
||||||
meta.value = value
|
meta.value = str(value)
|
||||||
|
|
||||||
onnx.save(model, filename)
|
onnx.save(model, filename)
|
||||||
|
|
||||||
@ -199,10 +199,15 @@ def export_model_onnx(
|
|||||||
)
|
)
|
||||||
|
|
||||||
meta_data = {
|
meta_data = {
|
||||||
"model_type": "VITS",
|
"model_type": "vits",
|
||||||
"version": "1",
|
"version": "1",
|
||||||
"model_author": "k2-fsa",
|
"model_author": "k2-fsa",
|
||||||
"comment": "VITS generator",
|
"comment": "icefall", # must be icefall for models from icefall
|
||||||
|
"language": "English",
|
||||||
|
"voice": "en-us", # Choose your language appropriately
|
||||||
|
"has_espeak": 1,
|
||||||
|
"n_speakers": 1,
|
||||||
|
"sample_rate": 22050, # Must match the real sample rate
|
||||||
}
|
}
|
||||||
logging.info(f"meta_data: {meta_data}")
|
logging.info(f"meta_data: {meta_data}")
|
||||||
|
|
||||||
@ -268,3 +273,144 @@ if __name__ == "__main__":
|
|||||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
"""
|
||||||
|
Supported languages.
|
||||||
|
|
||||||
|
LJSpeech is using "en-us" from the second column.
|
||||||
|
|
||||||
|
Pty Language Age/Gender VoiceName File Other Languages
|
||||||
|
5 af --/M Afrikaans gmw/af
|
||||||
|
5 am --/M Amharic sem/am
|
||||||
|
5 an --/M Aragonese roa/an
|
||||||
|
5 ar --/M Arabic sem/ar
|
||||||
|
5 as --/M Assamese inc/as
|
||||||
|
5 az --/M Azerbaijani trk/az
|
||||||
|
5 ba --/M Bashkir trk/ba
|
||||||
|
5 be --/M Belarusian zle/be
|
||||||
|
5 bg --/M Bulgarian zls/bg
|
||||||
|
5 bn --/M Bengali inc/bn
|
||||||
|
5 bpy --/M Bishnupriya_Manipuri inc/bpy
|
||||||
|
5 bs --/M Bosnian zls/bs
|
||||||
|
5 ca --/M Catalan roa/ca
|
||||||
|
5 chr-US-Qaaa-x-west --/M Cherokee_ iro/chr
|
||||||
|
5 cmn --/M Chinese_(Mandarin,_latin_as_English) sit/cmn (zh-cmn 5)(zh 5)
|
||||||
|
5 cmn-latn-pinyin --/M Chinese_(Mandarin,_latin_as_Pinyin) sit/cmn-Latn-pinyin (zh-cmn 5)(zh 5)
|
||||||
|
5 cs --/M Czech zlw/cs
|
||||||
|
5 cv --/M Chuvash trk/cv
|
||||||
|
5 cy --/M Welsh cel/cy
|
||||||
|
5 da --/M Danish gmq/da
|
||||||
|
5 de --/M German gmw/de
|
||||||
|
5 el --/M Greek grk/el
|
||||||
|
5 en-029 --/M English_(Caribbean) gmw/en-029 (en 10)
|
||||||
|
2 en-gb --/M English_(Great_Britain) gmw/en (en 2)
|
||||||
|
5 en-gb-scotland --/M English_(Scotland) gmw/en-GB-scotland (en 4)
|
||||||
|
5 en-gb-x-gbclan --/M English_(Lancaster) gmw/en-GB-x-gbclan (en-gb 3)(en 5)
|
||||||
|
5 en-gb-x-gbcwmd --/M English_(West_Midlands) gmw/en-GB-x-gbcwmd (en-gb 9)(en 9)
|
||||||
|
5 en-gb-x-rp --/M English_(Received_Pronunciation) gmw/en-GB-x-rp (en-gb 4)(en 5)
|
||||||
|
2 en-us --/M English_(America) gmw/en-US (en 3)
|
||||||
|
5 en-us-nyc --/M English_(America,_New_York_City) gmw/en-US-nyc
|
||||||
|
5 eo --/M Esperanto art/eo
|
||||||
|
5 es --/M Spanish_(Spain) roa/es
|
||||||
|
5 es-419 --/M Spanish_(Latin_America) roa/es-419 (es-mx 6)
|
||||||
|
5 et --/M Estonian urj/et
|
||||||
|
5 eu --/M Basque eu
|
||||||
|
5 fa --/M Persian ira/fa
|
||||||
|
5 fa-latn --/M Persian_(Pinglish) ira/fa-Latn
|
||||||
|
5 fi --/M Finnish urj/fi
|
||||||
|
5 fr-be --/M French_(Belgium) roa/fr-BE (fr 8)
|
||||||
|
5 fr-ch --/M French_(Switzerland) roa/fr-CH (fr 8)
|
||||||
|
5 fr-fr --/M French_(France) roa/fr (fr 5)
|
||||||
|
5 ga --/M Gaelic_(Irish) cel/ga
|
||||||
|
5 gd --/M Gaelic_(Scottish) cel/gd
|
||||||
|
5 gn --/M Guarani sai/gn
|
||||||
|
5 grc --/M Greek_(Ancient) grk/grc
|
||||||
|
5 gu --/M Gujarati inc/gu
|
||||||
|
5 hak --/M Hakka_Chinese sit/hak
|
||||||
|
5 haw --/M Hawaiian map/haw
|
||||||
|
5 he --/M Hebrew sem/he
|
||||||
|
5 hi --/M Hindi inc/hi
|
||||||
|
5 hr --/M Croatian zls/hr (hbs 5)
|
||||||
|
5 ht --/M Haitian_Creole roa/ht
|
||||||
|
5 hu --/M Hungarian urj/hu
|
||||||
|
5 hy --/M Armenian_(East_Armenia) ine/hy (hy-arevela 5)
|
||||||
|
5 hyw --/M Armenian_(West_Armenia) ine/hyw (hy-arevmda 5)(hy 8)
|
||||||
|
5 ia --/M Interlingua art/ia
|
||||||
|
5 id --/M Indonesian poz/id
|
||||||
|
5 io --/M Ido art/io
|
||||||
|
5 is --/M Icelandic gmq/is
|
||||||
|
5 it --/M Italian roa/it
|
||||||
|
5 ja --/M Japanese jpx/ja
|
||||||
|
5 jbo --/M Lojban art/jbo
|
||||||
|
5 ka --/M Georgian ccs/ka
|
||||||
|
5 kk --/M Kazakh trk/kk
|
||||||
|
5 kl --/M Greenlandic esx/kl
|
||||||
|
5 kn --/M Kannada dra/kn
|
||||||
|
5 ko --/M Korean ko
|
||||||
|
5 kok --/M Konkani inc/kok
|
||||||
|
5 ku --/M Kurdish ira/ku
|
||||||
|
5 ky --/M Kyrgyz trk/ky
|
||||||
|
5 la --/M Latin itc/la
|
||||||
|
5 lb --/M Luxembourgish gmw/lb
|
||||||
|
5 lfn --/M Lingua_Franca_Nova art/lfn
|
||||||
|
5 lt --/M Lithuanian bat/lt
|
||||||
|
5 ltg --/M Latgalian bat/ltg
|
||||||
|
5 lv --/M Latvian bat/lv
|
||||||
|
5 mi --/M Māori poz/mi
|
||||||
|
5 mk --/M Macedonian zls/mk
|
||||||
|
5 ml --/M Malayalam dra/ml
|
||||||
|
5 mr --/M Marathi inc/mr
|
||||||
|
5 ms --/M Malay poz/ms
|
||||||
|
5 mt --/M Maltese sem/mt
|
||||||
|
5 mto --/M Totontepec_Mixe miz/mto
|
||||||
|
5 my --/M Myanmar_(Burmese) sit/my
|
||||||
|
5 nb --/M Norwegian_Bokmål gmq/nb (no 5)
|
||||||
|
5 nci --/M Nahuatl_(Classical) azc/nci
|
||||||
|
5 ne --/M Nepali inc/ne
|
||||||
|
5 nl --/M Dutch gmw/nl
|
||||||
|
5 nog --/M Nogai trk/nog
|
||||||
|
5 om --/M Oromo cus/om
|
||||||
|
5 or --/M Oriya inc/or
|
||||||
|
5 pa --/M Punjabi inc/pa
|
||||||
|
5 pap --/M Papiamento roa/pap
|
||||||
|
5 piqd --/M Klingon art/piqd
|
||||||
|
5 pl --/M Polish zlw/pl
|
||||||
|
5 pt --/M Portuguese_(Portugal) roa/pt (pt-pt 5)
|
||||||
|
5 pt-br --/M Portuguese_(Brazil) roa/pt-BR (pt 6)
|
||||||
|
5 py --/M Pyash art/py
|
||||||
|
5 qdb --/M Lang_Belta art/qdb
|
||||||
|
5 qu --/M Quechua qu
|
||||||
|
5 quc --/M K'iche' myn/quc
|
||||||
|
5 qya --/M Quenya art/qya
|
||||||
|
5 ro --/M Romanian roa/ro
|
||||||
|
5 ru --/M Russian zle/ru
|
||||||
|
5 ru-cl --/M Russian_(Classic) zle/ru-cl
|
||||||
|
2 ru-lv --/M Russian_(Latvia) zle/ru-LV
|
||||||
|
5 sd --/M Sindhi inc/sd
|
||||||
|
5 shn --/M Shan_(Tai_Yai) tai/shn
|
||||||
|
5 si --/M Sinhala inc/si
|
||||||
|
5 sjn --/M Sindarin art/sjn
|
||||||
|
5 sk --/M Slovak zlw/sk
|
||||||
|
5 sl --/M Slovenian zls/sl
|
||||||
|
5 smj --/M Lule_Saami urj/smj
|
||||||
|
5 sq --/M Albanian ine/sq
|
||||||
|
5 sr --/M Serbian zls/sr
|
||||||
|
5 sv --/M Swedish gmq/sv
|
||||||
|
5 sw --/M Swahili bnt/sw
|
||||||
|
5 ta --/M Tamil dra/ta
|
||||||
|
5 te --/M Telugu dra/te
|
||||||
|
5 th --/M Thai tai/th
|
||||||
|
5 tk --/M Turkmen trk/tk
|
||||||
|
5 tn --/M Setswana bnt/tn
|
||||||
|
5 tr --/M Turkish trk/tr
|
||||||
|
5 tt --/M Tatar trk/tt
|
||||||
|
5 ug --/M Uyghur trk/ug
|
||||||
|
5 uk --/M Ukrainian zle/uk
|
||||||
|
5 ur --/M Urdu inc/ur
|
||||||
|
5 uz --/M Uzbek trk/uz
|
||||||
|
5 vi --/M Vietnamese_(Northern) aav/vi
|
||||||
|
5 vi-vn-x-central --/M Vietnamese_(Central) aav/vi-VN-x-central
|
||||||
|
5 vi-vn-x-south --/M Vietnamese_(Southern) aav/vi-VN-x-south
|
||||||
|
5 yue --/M Chinese_(Cantonese) sit/yue (zh-yue 5)(zh 8)
|
||||||
|
5 yue --/M Chinese_(Cantonese,_latin_as_Jyutping) sit/yue-Latn-jyutping (zh-yue 5)(zh 8)
|
||||||
|
"""
|
||||||
|
@ -5,9 +5,9 @@ This file prints the text field of supervisions from cutset to the console
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from lhotse import load_manifest_lazy
|
from lhotse import load_manifest_lazy
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
|
@ -5,7 +5,6 @@ This file generates words.txt from the given transcript file.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,7 +29,6 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import SwitchBoardAsrDataModule
|
from asr_datamodule import SwitchBoardAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
|
||||||
from sclite_scoring import asr_text_post_processing
|
from sclite_scoring import asr_text_post_processing
|
||||||
|
|
||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
|
@ -16,8 +16,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
from pathlib import Path
|
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,6 +45,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall import smart_byte_decode
|
from icefall import smart_byte_decode
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,9 +19,9 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import add_sos, make_pad_mask
|
from icefall.utils import add_sos, make_pad_mask
|
||||||
from scaling import ScaledLinear
|
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
class Transducer(nn.Module):
|
||||||
|
@ -17,10 +17,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
|
||||||
import lhotse
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import lhotse
|
||||||
|
import torch
|
||||||
from lhotse import (
|
from lhotse import (
|
||||||
CutSet,
|
CutSet,
|
||||||
Fbank,
|
Fbank,
|
||||||
@ -29,6 +29,7 @@ from lhotse import (
|
|||||||
fix_manifests,
|
fix_manifests,
|
||||||
validate_recordings_and_supervisions,
|
validate_recordings_and_supervisions,
|
||||||
)
|
)
|
||||||
|
|
||||||
from icefall.utils import get_executor, str2bool
|
from icefall.utils import get_executor, str2bool
|
||||||
|
|
||||||
# Torch's multithreaded behavior needs to be disabled or
|
# Torch's multithreaded behavior needs to be disabled or
|
||||||
|
@ -41,6 +41,7 @@ from prepare_lang import (
|
|||||||
write_lexicon,
|
write_lexicon,
|
||||||
write_mapping,
|
write_mapping,
|
||||||
)
|
)
|
||||||
|
|
||||||
from icefall.utils import text_to_pinyin
|
from icefall.utils import text_to_pinyin
|
||||||
|
|
||||||
|
|
||||||
|
@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp:
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from icefall import is_module_available
|
import torch
|
||||||
from onnx_pretrained import OnnxModel
|
from onnx_pretrained import OnnxModel
|
||||||
|
|
||||||
import torch
|
from icefall import is_module_available
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -30,9 +30,7 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import WenetSpeechAsrDataModule
|
from asr_datamodule import WenetSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import keywords_search
|
||||||
keywords_search,
|
|
||||||
)
|
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
from train import add_model_arguments, get_model, get_params
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
|
@ -87,6 +87,19 @@ from torch import Tensor
|
|||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
from train import (
|
||||||
|
add_model_arguments,
|
||||||
|
add_training_arguments,
|
||||||
|
compute_validation_loss,
|
||||||
|
display_and_save_batch,
|
||||||
|
get_adjusted_batch_count,
|
||||||
|
get_model,
|
||||||
|
get_params,
|
||||||
|
load_checkpoint_if_available,
|
||||||
|
save_checkpoint,
|
||||||
|
scan_pessimistic_batches_for_oom,
|
||||||
|
set_batch_count,
|
||||||
|
)
|
||||||
|
|
||||||
from icefall import diagnostics
|
from icefall import diagnostics
|
||||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||||
@ -109,21 +122,6 @@ from icefall.utils import (
|
|||||||
text_to_pinyin,
|
text_to_pinyin,
|
||||||
)
|
)
|
||||||
|
|
||||||
from train import (
|
|
||||||
add_model_arguments,
|
|
||||||
add_training_arguments,
|
|
||||||
compute_validation_loss,
|
|
||||||
display_and_save_batch,
|
|
||||||
get_adjusted_batch_count,
|
|
||||||
get_model,
|
|
||||||
get_params,
|
|
||||||
load_checkpoint_if_available,
|
|
||||||
save_checkpoint,
|
|
||||||
scan_pessimistic_batches_for_oom,
|
|
||||||
set_batch_count,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,7 +99,6 @@ from icefall.utils import (
|
|||||||
text_to_pinyin,
|
text_to_pinyin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,9 +18,8 @@ you can use ./export.py --jit 1
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
|
||||||
import math
|
import math
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
|
@ -8,7 +8,6 @@
|
|||||||
import re
|
import re
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
|
||||||
|
|
||||||
WHITESPACE_NORMALIZER = re.compile(r"\s+")
|
WHITESPACE_NORMALIZER = re.compile(r"\s+")
|
||||||
SPACE = chr(32)
|
SPACE = chr(32)
|
||||||
SPACE_ESCAPE = chr(9601)
|
SPACE_ESCAPE = chr(9601)
|
||||||
|
@ -8,12 +8,12 @@ The lang_dir should contain the following files:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import kaldifst
|
import kaldifst
|
||||||
import re
|
|
||||||
|
|
||||||
|
|
||||||
class Lexicon:
|
class Lexicon:
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
|
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Tuple, List
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
@ -5,14 +5,15 @@
|
|||||||
|
|
||||||
# This is modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py
|
# This is modified from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from functools import partial
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from functools import partial
|
|
||||||
from typing import List, Optional
|
|
||||||
from collections import OrderedDict
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
Tensor = torch.Tensor
|
Tensor = torch.Tensor
|
||||||
|
|
||||||
|
@ -5,16 +5,16 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import onnx
|
import onnx
|
||||||
import torch
|
import torch
|
||||||
from model import RnnLmModel
|
from model import RnnLmModel
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
from train import get_params
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, str2bool
|
||||||
from typing import Dict
|
|
||||||
from train import get_params
|
|
||||||
|
|
||||||
|
|
||||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||||
|
@ -8,6 +8,10 @@ pypinyin==0.50.0
|
|||||||
tensorboard
|
tensorboard
|
||||||
typeguard
|
typeguard
|
||||||
dill
|
dill
|
||||||
black==22.3.0
|
|
||||||
onnx==1.15.0
|
onnx==1.15.0
|
||||||
onnxruntime==1.16.3
|
onnxruntime==1.16.3
|
||||||
|
|
||||||
|
# style check session:
|
||||||
|
black==22.3.0
|
||||||
|
isort==5.10.1
|
||||||
|
flake8==5.0.4
|
Loading…
x
Reference in New Issue
Block a user