This commit is contained in:
AmirHussein96 2024-04-05 13:08:04 -04:00
parent 891cf55901
commit e14dae4b11
18 changed files with 62 additions and 48 deletions

View File

@ -9,6 +9,7 @@ This file cer from icefall decoded "recogs" file:
""" """
import argparse import argparse
import jiwer import jiwer

View File

@ -23,20 +23,19 @@ It looks for manifests in the directory data_seame/manifests.
The generated fbank features are saved in data_seame/fbank. The generated fbank features are saved in data_seame/fbank.
""" """
import argparse
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
import argparse
from lhotse import CutSet, LilcomChunkyWriter from lhotse import CutSet, LilcomChunkyWriter
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.features.kaldifeat import ( from lhotse.features.kaldifeat import (
KaldifeatFbank, KaldifeatFbank,
KaldifeatFbankConfig, KaldifeatFbankConfig,
KaldifeatFrameOptions, KaldifeatFrameOptions,
KaldifeatMelOptions, KaldifeatMelOptions,
) )
from lhotse.recipes.utils import read_manifests_if_cached
def get_args(): def get_args():

View File

@ -23,13 +23,12 @@ It looks for manifests in the directory data_seame/manifests.
The generated fbank features are saved in data_seame/fbank. The generated fbank features are saved in data_seame/fbank.
""" """
import argparse
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
import argparse
from lhotse import CutSet, LilcomChunkyWriter from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import ( from lhotse.features.kaldifeat import (
KaldifeatFbank, KaldifeatFbank,
KaldifeatFbankConfig, KaldifeatFbankConfig,

View File

@ -1,11 +1,12 @@
#!/usr/bin/python #!/usr/bin/python
from lhotse import RecordingSet, SupervisionSet, CutSet
import argparse import argparse
import logging import logging
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
import pdb import pdb
from lhotse import CutSet, RecordingSet, SupervisionSet
from lhotse.qa import fix_manifests, validate_recordings_and_supervisions
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(

View File

@ -35,6 +35,7 @@ and generates the following files in the directory `lang_dir`:
""" """
import argparse import argparse
import pdb
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
@ -50,7 +51,6 @@ from prepare_lang import (
) )
from icefall.utils import str2bool from icefall.utils import str2bool
import pdb
def lexicon_to_fst_no_sil( def lexicon_to_fst_no_sil(

View File

@ -5,12 +5,13 @@
This script prepares transcript_words.txt from cutset This script prepares transcript_words.txt from cutset
""" """
from lhotse import CutSet
import argparse import argparse
import logging import logging
import os
import pdb import pdb
from pathlib import Path from pathlib import Path
import os
from lhotse import CutSet
def get_parser(): def get_parser():

View File

@ -5,12 +5,13 @@
Sample data given duration in seconds. Sample data given duration in seconds.
""" """
from lhotse import RecordingSet, SupervisionSet, CutSet
import argparse import argparse
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
from lhotse import CutSet, RecordingSet, SupervisionSet
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(

View File

@ -29,6 +29,7 @@ import argparse
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
import sentencepiece as spm import sentencepiece as spm

View File

@ -5,11 +5,16 @@
Compute WER per language Compute WER per language
""" """
import sys, codecs, math, pickle, unicodedata, re
from collections import Counter
import argparse import argparse
import codecs
import math
import pickle
import re
import sys
import unicodedata
from collections import Counter, defaultdict
from kaldialign import align from kaldialign import align
from collections import defaultdict
def get_parser(): def get_parser():

View File

@ -64,6 +64,8 @@ import argparse
import logging import logging
import math import math
import os import os
import re
import string
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -105,9 +107,6 @@ from icefall.utils import (
str2bool, str2bool,
write_error_stats, write_error_stats,
) )
import string
import re
LOG_EPS = math.log(1e-10) LOG_EPS = math.log(1e-10)

View File

@ -22,24 +22,24 @@ Usage: ./zipformer/profile.py
import argparse import argparse
import logging import logging
from typing import Tuple
import sentencepiece as spm import sentencepiece as spm
import torch import torch
from typing import Tuple
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.profiler import get_model_profile
from scaling import BiasNorm from scaling import BiasNorm
from torch import Tensor, nn
from train import ( from train import (
add_model_arguments,
get_encoder_embed, get_encoder_embed,
get_encoder_model, get_encoder_model,
get_joiner_model, get_joiner_model,
add_model_arguments,
get_params, get_params,
) )
from zipformer import BypassModule from zipformer import BypassModule
from icefall.profiler import get_model_profile
from icefall.utils import make_pad_mask
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(

View File

@ -52,6 +52,8 @@ import argparse
import logging import logging
import math import math
import os import os
import re
import string
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -64,8 +66,8 @@ from asr_datamodule import SeameAsrDataModule
from beam_search import ( from beam_search import (
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_lm_rescore_LODR, modified_beam_search_lm_rescore_LODR,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR, modified_beam_search_LODR,
) )
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
@ -86,8 +88,6 @@ from icefall.utils import (
str2bool, str2bool,
write_error_stats, write_error_stats,
) )
import string
import re
LOG_EPS = math.log(1e-10) LOG_EPS = math.log(1e-10)

View File

@ -22,9 +22,9 @@ import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class AsrModel(nn.Module): class AsrModel(nn.Module):

View File

@ -75,11 +75,15 @@ import argparse
import logging import logging
import math import math
import os import os
import re
import string
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import k2 import k2
import matplotlib.pyplot as plt
import seaborn as sns
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -87,10 +91,12 @@ from asr_datamodule import SeameAsrDataModule
from beam_search import ( from beam_search import (
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_lm_rescore_LODR, modified_beam_search_lm_rescore_LODR,
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR, modified_beam_search_LODR,
) )
from kaldialign import align
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm from icefall import ContextGraph, LmScorer, NgramLm
@ -109,12 +115,6 @@ from icefall.utils import (
str2bool, str2bool,
write_error_stats, write_error_stats,
) )
from kaldialign import align
from sklearn.metrics import f1_score, classification_report, confusion_matrix
import string
import re
import seaborn as sns
import matplotlib.pyplot as plt
LOG_EPS = math.log(1e-10) LOG_EPS = math.log(1e-10)

View File

@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from scaling import ScaledLinear from scaling import ScaledLinear
from typing import Optional
class Joiner(nn.Module): class Joiner(nn.Module):

View File

@ -23,9 +23,9 @@ import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask from icefall.utils import add_sos, make_pad_mask
from scaling import ScaledLinear
class AsrModel(nn.Module): class AsrModel(nn.Module):

View File

@ -100,7 +100,7 @@ import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, Optional, Tuple, Union
from torch.optim import Optimizer
import k2 import k2
import optim import optim
import sentencepiece as spm import sentencepiece as spm
@ -120,6 +120,7 @@ from subsampling import Conv2dSubsampling
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2 from zipformer import Zipformer2

View File

@ -17,28 +17,33 @@
# limitations under the License. # limitations under the License.
import copy import copy
import logging
import math import math
import random
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import logging
import torch import torch
import random
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from scaling import ( from scaling import (
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
)
from scaling import (
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
)
from scaling import (
ActivationDropoutAndLinear,
Balancer, Balancer,
BiasNorm, BiasNorm,
Dropout2,
ChunkCausalDepthwiseConv1d, ChunkCausalDepthwiseConv1d,
ActivationDropoutAndLinear, Dropout2,
ScaledLinear, # not as in other dirs.. just scales down initial parameter values. FloatLike,
ScheduledFloat,
Whiten, Whiten,
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. convert_num_channels,
limit_param_value,
penalize_abs_values_gt, penalize_abs_values_gt,
softmax, softmax,
ScheduledFloat,
FloatLike,
limit_param_value,
convert_num_channels,
) )
from torch import Tensor, nn from torch import Tensor, nn