mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
training libricss surt model
This commit is contained in:
parent
9ed22396a9
commit
d50cef82cc
372
egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
372
egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
@ -0,0 +1,372 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
# Copyright 2023 Johns Hopkins Univrtsity (Author: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SurtDataset,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriCssAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/manifests"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration-valid",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-cuts",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of cuts in a single batch. You can "
|
||||
"reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help=(
|
||||
"When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
return_sources: bool = True,
|
||||
strict: bool = True,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=return_sources,
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
quadratic_duration=30.0,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
validate = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def lsmix_cuts(
|
||||
self,
|
||||
rvb_affix: str = "clean",
|
||||
type_affix: str = "full",
|
||||
sources: bool = True,
|
||||
) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
source_affix = "_sources" if sources else ""
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir
|
||||
/ f"cuts_train_{rvb_affix}_{type_affix}{source_affix}.jsonl.gz"
|
||||
)
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def libricss_cuts(self, split="dev", type="sdm") -> CutSet:
|
||||
logging.info(f"About to get LibriCSS {split} {type} cuts")
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz"
|
||||
)
|
||||
return cs
|
1280
egs/libricss/SURT/dprnn_zipformer/beam_search.py
Normal file
1280
egs/libricss/SURT/dprnn_zipformer/beam_search.py
Normal file
File diff suppressed because it is too large
Load Diff
871
egs/libricss/SURT/dprnn_zipformer/decode.py
Executable file
871
egs/libricss/SURT/dprnn_zipformer/decode.py
Executable file
@ -0,0 +1,871 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--use-averaged-model true \
|
||||
--exp-dir ./dprnn_zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) modified beam search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--use-averaged-model true \
|
||||
--exp-dir ./dprnn_zipformer/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from itertools import chain, groupby, repeat
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriCssAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
modified_beam_search_LODR,
|
||||
)
|
||||
from lhotse.utils import EPSILON
|
||||
from train import add_model_arguments, get_params, get_surt_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_surt_error_stats,
|
||||
)
|
||||
|
||||
OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"]
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="dprnn_zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
If you use fast_beam_search_nbest_LG, you have to specify
|
||||
`--lang-dir`, which should contain `LG.pt`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save-masks",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""If true, save masks generated by unmixing module.""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
feature_lens = batch["input_lens"].to(device)
|
||||
|
||||
# Apply the mask encoder
|
||||
B, T, F = feature.shape
|
||||
processed = model.mask_encoder(feature) # B,T,F*num_channels
|
||||
masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
|
||||
x_masked = [feature * m for m in masks]
|
||||
|
||||
# To save the masks, we split them by batch and trim each mask to the length of
|
||||
# the corresponding feature. We save them in a dict, where the key is the
|
||||
# cut ID and the value is the mask.
|
||||
masks_dict = {}
|
||||
for i in range(B):
|
||||
mask = torch.cat(
|
||||
[x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
|
||||
dim=-1,
|
||||
)
|
||||
mask = mask.cpu().numpy()
|
||||
masks_dict[batch["cuts"][i].id] = mask
|
||||
|
||||
# Recognition
|
||||
# Stack the inputs along the batch axis
|
||||
h = torch.cat(x_masked, dim=0)
|
||||
h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
|
||||
|
||||
if model.joint_encoder_layer is not None:
|
||||
encoder_out = model.joint_encoder_layer(encoder_out)
|
||||
|
||||
def _group_channels(hyps: List[str]) -> List[List[str]]:
|
||||
"""
|
||||
Currently we have a batch of size M*B, where M is the number of
|
||||
channels and B is the batch size. We need to group the hypotheses
|
||||
into B groups, each of which contains M hypotheses.
|
||||
|
||||
Example:
|
||||
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
|
||||
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
|
||||
"""
|
||||
assert len(hyps) == B * params.num_channels
|
||||
out_hyps = []
|
||||
for i in range(B):
|
||||
out_hyps.append(hyps[i::B])
|
||||
return out_hyps
|
||||
|
||||
hyps = []
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "modified_beam_search_LODR":
|
||||
hyp_tokens = modified_beam_search_LODR(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
LODR_lm=ngram_lm,
|
||||
LODR_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp))
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": _group_channels(hyps)}, masks_dict
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: _group_channels(hyps)}, masks_dict
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}, masks_dict
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
ngram_lm: Optional[NgramLm] = None,
|
||||
ngram_lm_scale: float = 1.0,
|
||||
LM: Optional[LmScorer] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
masks = {}
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
cut_ids = [cut.id for cut in batch["cuts"]]
|
||||
cuts_batch = batch["cuts"]
|
||||
|
||||
hyps_dict, masks_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
masks.update(masks_dict)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
for cut_id, hyp_words in zip(cut_ids, hyps):
|
||||
# Reference is a list of supervision texts sorted by start time.
|
||||
ref_words = [
|
||||
s.text.strip()
|
||||
for s in sorted(
|
||||
cuts_batch[cut_id].supervisions, key=lambda s: s.start
|
||||
)
|
||||
]
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(cut_ids)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results, masks_dict
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_surt_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
num_channels=params.num_channels,
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
def save_masks(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
masks: List[torch.Tensor],
|
||||
):
|
||||
masks_path = params.res_dir / f"masks-{test_set_name}.txt"
|
||||
torch.save(masks, masks_path)
|
||||
logging.info(f"The masks are stored in {masks_path}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LmScorer.add_arguments(parser)
|
||||
LibriCssAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
"modified_beam_search_LODR",
|
||||
), f"Decoding method {params.decoding_method} is not supported."
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
if "LODR" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
|
||||
)
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_surt_model(params)
|
||||
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
model.encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
# only load N-gram LM when needed
|
||||
if "LODR" in params.decoding_method:
|
||||
lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt"
|
||||
logging.info(f"lm filename: {lm_filename}")
|
||||
ngram_lm = NgramLm(
|
||||
lm_filename,
|
||||
backoff_id=params.backoff_id,
|
||||
is_binary=False,
|
||||
)
|
||||
logging.info(f"num states: {ngram_lm.lm.num_states}")
|
||||
ngram_lm_scale = params.ngram_lm_scale
|
||||
else:
|
||||
ngram_lm = None
|
||||
ngram_lm_scale = None
|
||||
|
||||
# only load the neural network LM if doing shallow fusion
|
||||
if params.use_shallow_fusion:
|
||||
LM = LmScorer(
|
||||
lm_type=params.lm_type,
|
||||
params=params,
|
||||
device=device,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
LM.to(device)
|
||||
LM.eval()
|
||||
|
||||
else:
|
||||
LM = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
libricss = LibriCssAsrDataModule(args)
|
||||
|
||||
dev_cuts = libricss.libricss_cuts(split="dev", type="ihm-mix").to_eager()
|
||||
dev_cuts_grouped = [dev_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS]
|
||||
test_cuts = libricss.libricss_cuts(split="test", type="ihm-mix").to_eager()
|
||||
test_cuts_grouped = [
|
||||
test_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS
|
||||
]
|
||||
|
||||
for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS):
|
||||
dev_dl = libricss.test_dataloaders(dev_set)
|
||||
results_dict, masks = decode_dataset(
|
||||
dl=dev_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=f"dev_{ol}",
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if params.save_masks:
|
||||
save_masks(
|
||||
params=params,
|
||||
test_set_name=f"dev_{ol}",
|
||||
masks=masks,
|
||||
)
|
||||
|
||||
for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
|
||||
test_dl = libricss.test_dataloaders(test_set)
|
||||
results_dict, masks = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
ngram_lm=ngram_lm,
|
||||
ngram_lm_scale=ngram_lm_scale,
|
||||
LM=LM,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=f"test_{ol}",
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
if params.save_masks:
|
||||
save_masks(
|
||||
params=params,
|
||||
test_set_name=f"test_{ol}",
|
||||
masks=masks,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
102
egs/libricss/SURT/dprnn_zipformer/decoder.py
Normal file
102
egs/libricss/SURT/dprnn_zipformer/decoder.py
Normal file
@ -0,0 +1,102 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
decoder_dim: int,
|
||||
blank_id: int,
|
||||
context_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
Number of previous words to use to predict the next word.
|
||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.embedding = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=decoder_dim,
|
||||
padding_idx=blank_id,
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
if context_size > 1:
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels=decoder_dim,
|
||||
out_channels=decoder_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=decoder_dim // 4, # group size == 4
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
y:
|
||||
A 2-D tensor of shape (N, U).
|
||||
need_pad:
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
# this stuff about clamp() is a temporary fix for a mismatch
|
||||
# at utterance start, we use negative ids in beam_search.py
|
||||
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)
|
||||
if self.context_size > 1:
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
if need_pad is True:
|
||||
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
|
||||
else:
|
||||
# During inference time, there is no need to do extra padding
|
||||
# as we only need one output
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = F.relu(embedding_out)
|
||||
return embedding_out
|
305
egs/libricss/SURT/dprnn_zipformer/dprnn.py
Normal file
305
egs/libricss/SURT/dprnn_zipformer/dprnn.py
Normal file
@ -0,0 +1,305 @@
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from scaling import ActivationBalancer, BasicNorm, DoubleSwish, ScaledLinear, ScaledLSTM
|
||||
from torch.autograd import Variable
|
||||
|
||||
EPS = torch.finfo(torch.get_default_dtype()).eps
|
||||
|
||||
|
||||
def _pad_segment(input, segment_size):
|
||||
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L342
|
||||
# input is the features: (B, N, T)
|
||||
batch_size, dim, seq_len = input.shape
|
||||
segment_stride = segment_size // 2
|
||||
|
||||
rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
|
||||
if rest > 0:
|
||||
pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
|
||||
input = torch.cat([input, pad], 2)
|
||||
|
||||
pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type())
|
||||
input = torch.cat([pad_aux, input, pad_aux], 2)
|
||||
|
||||
return input, rest
|
||||
|
||||
|
||||
def split_feature(input, segment_size):
|
||||
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L358
|
||||
# split the feature into chunks of segment size
|
||||
# input is the features: (B, N, T)
|
||||
|
||||
input, rest = _pad_segment(input, segment_size)
|
||||
batch_size, dim, seq_len = input.shape
|
||||
segment_stride = segment_size // 2
|
||||
|
||||
segments1 = (
|
||||
input[:, :, :-segment_stride]
|
||||
.contiguous()
|
||||
.view(batch_size, dim, -1, segment_size)
|
||||
)
|
||||
segments2 = (
|
||||
input[:, :, segment_stride:]
|
||||
.contiguous()
|
||||
.view(batch_size, dim, -1, segment_size)
|
||||
)
|
||||
segments = (
|
||||
torch.cat([segments1, segments2], 3)
|
||||
.view(batch_size, dim, -1, segment_size)
|
||||
.transpose(2, 3)
|
||||
)
|
||||
|
||||
return segments.contiguous(), rest
|
||||
|
||||
|
||||
def merge_feature(input, rest):
|
||||
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L385
|
||||
# merge the splitted features into full utterance
|
||||
# input is the features: (B, N, L, K)
|
||||
|
||||
batch_size, dim, segment_size, _ = input.shape
|
||||
segment_stride = segment_size // 2
|
||||
input = (
|
||||
input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2)
|
||||
) # B, N, K, L
|
||||
|
||||
input1 = (
|
||||
input[:, :, :, :segment_size]
|
||||
.contiguous()
|
||||
.view(batch_size, dim, -1)[:, :, segment_stride:]
|
||||
)
|
||||
input2 = (
|
||||
input[:, :, :, segment_size:]
|
||||
.contiguous()
|
||||
.view(batch_size, dim, -1)[:, :, :-segment_stride]
|
||||
)
|
||||
|
||||
output = input1 + input2
|
||||
if rest > 0:
|
||||
output = output[:, :, :-rest]
|
||||
|
||||
return output.contiguous() # B, N, T
|
||||
|
||||
|
||||
class RNNEncoderLayer(nn.Module):
|
||||
"""
|
||||
RNNEncoderLayer is made up of lstm and feedforward networks.
|
||||
Args:
|
||||
input_size:
|
||||
The number of expected features in the input (required).
|
||||
hidden_size:
|
||||
The hidden dimension of rnn layer.
|
||||
dropout:
|
||||
The dropout value (default=0.1).
|
||||
layer_dropout:
|
||||
The dropout value for model-level warmup (default=0.075).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
dropout: float = 0.1,
|
||||
bidirectional: bool = False,
|
||||
) -> None:
|
||||
super(RNNEncoderLayer, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
assert hidden_size >= input_size, (hidden_size, input_size)
|
||||
self.lstm = ScaledLSTM(
|
||||
input_size=input_size,
|
||||
hidden_size=hidden_size // 2 if bidirectional else hidden_size,
|
||||
proj_size=0,
|
||||
num_layers=1,
|
||||
dropout=0.0,
|
||||
batch_first=True,
|
||||
bidirectional=bidirectional,
|
||||
)
|
||||
self.norm_final = BasicNorm(input_size)
|
||||
|
||||
# try to ensure the output is close to zero-mean (or at least, zero-median). # noqa
|
||||
self.balancer = ActivationBalancer(
|
||||
num_channels=input_size,
|
||||
channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55,
|
||||
max_abs=6.0,
|
||||
)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: torch.Tensor,
|
||||
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
warmup: float = 1.0,
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
Args:
|
||||
src:
|
||||
The sequence to the encoder layer (required).
|
||||
Its shape is (S, N, E), where S is the sequence length,
|
||||
N is the batch size, and E is the feature number.
|
||||
states:
|
||||
A tuple of 2 tensors (optional). It is for streaming inference.
|
||||
states[0] is the hidden states of all layers,
|
||||
with shape of (1, N, input_size);
|
||||
states[1] is the cell states of all layers,
|
||||
with shape of (1, N, hidden_size).
|
||||
"""
|
||||
src_orig = src
|
||||
|
||||
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
|
||||
# completely bypass it.
|
||||
alpha = warmup if self.training else 1.0
|
||||
|
||||
# lstm module
|
||||
src_lstm, new_states = self.lstm(src, states)
|
||||
src = self.dropout(src_lstm) + src
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
if alpha != 1.0:
|
||||
src = alpha * src + (1 - alpha) * src_orig
|
||||
|
||||
return src
|
||||
|
||||
|
||||
# dual-path RNN
|
||||
class DPRNN(nn.Module):
|
||||
"""Deep dual-path RNN.
|
||||
Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py
|
||||
|
||||
args:
|
||||
input_size: int, dimension of the input feature. The input should have shape
|
||||
(batch, seq_len, input_size).
|
||||
hidden_size: int, dimension of the hidden state.
|
||||
output_size: int, dimension of the output size.
|
||||
dropout: float, dropout ratio. Default is 0.
|
||||
num_blocks: int, number of stacked RNN layers. Default is 1.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_dim,
|
||||
input_size,
|
||||
hidden_size,
|
||||
output_size,
|
||||
dropout=0.1,
|
||||
num_blocks=1,
|
||||
segment_size=50,
|
||||
chunk_width_randomization=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.segment_size = segment_size
|
||||
self.chunk_width_randomization = chunk_width_randomization
|
||||
|
||||
self.input_embed = nn.Sequential(
|
||||
ScaledLinear(feature_dim, input_size),
|
||||
BasicNorm(input_size),
|
||||
ActivationBalancer(
|
||||
num_channels=input_size,
|
||||
channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55,
|
||||
),
|
||||
)
|
||||
|
||||
# dual-path RNN
|
||||
self.row_rnn = nn.ModuleList([])
|
||||
self.col_rnn = nn.ModuleList([])
|
||||
for _ in range(num_blocks):
|
||||
# intra-RNN is non-causal
|
||||
self.row_rnn.append(
|
||||
RNNEncoderLayer(
|
||||
input_size, hidden_size, dropout=dropout, bidirectional=True
|
||||
)
|
||||
)
|
||||
self.col_rnn.append(
|
||||
RNNEncoderLayer(
|
||||
input_size, hidden_size, dropout=dropout, bidirectional=False
|
||||
)
|
||||
)
|
||||
|
||||
# output layer
|
||||
self.out_embed = nn.Sequential(
|
||||
ScaledLinear(input_size, output_size),
|
||||
BasicNorm(output_size),
|
||||
ActivationBalancer(
|
||||
num_channels=output_size,
|
||||
channel_dim=-1,
|
||||
min_positive=0.45,
|
||||
max_positive=0.55,
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
# input shape: B, T, F
|
||||
input = self.input_embed(input)
|
||||
B, T, D = input.shape
|
||||
|
||||
if self.chunk_width_randomization and self.training:
|
||||
segment_size = random.randint(self.segment_size // 2, self.segment_size)
|
||||
else:
|
||||
segment_size = self.segment_size
|
||||
input, rest = split_feature(input.transpose(1, 2), segment_size)
|
||||
# input shape: batch, N, dim1, dim2
|
||||
# apply RNN on dim1 first and then dim2
|
||||
# output shape: B, output_size, dim1, dim2
|
||||
# input = input.to(device)
|
||||
batch_size, _, dim1, dim2 = input.shape
|
||||
output = input
|
||||
for i in range(len(self.row_rnn)):
|
||||
row_input = (
|
||||
output.permute(0, 3, 2, 1)
|
||||
.contiguous()
|
||||
.view(batch_size * dim2, dim1, -1)
|
||||
) # B*dim2, dim1, N
|
||||
output = self.row_rnn[i](row_input) # B*dim2, dim1, H
|
||||
output = (
|
||||
output.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
|
||||
) # B, N, dim1, dim2
|
||||
|
||||
col_input = (
|
||||
output.permute(0, 2, 3, 1)
|
||||
.contiguous()
|
||||
.view(batch_size * dim1, dim2, -1)
|
||||
) # B*dim1, dim2, N
|
||||
output = self.col_rnn[i](col_input) # B*dim1, dim2, H
|
||||
output = (
|
||||
output.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
|
||||
) # B, N, dim1, dim2
|
||||
|
||||
output = merge_feature(output, rest)
|
||||
output = output.transpose(1, 2)
|
||||
output = self.out_embed(output)
|
||||
|
||||
# Apply ReLU to the output
|
||||
output = torch.relu(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
model = DPRNN(
|
||||
80,
|
||||
256,
|
||||
256,
|
||||
160,
|
||||
dropout=0.1,
|
||||
num_blocks=4,
|
||||
segment_size=32,
|
||||
chunk_width_randomization=True,
|
||||
)
|
||||
input = torch.randn(2, 1002, 80)
|
||||
print(sum(p.numel() for p in model.parameters()))
|
||||
print(model(input).shape)
|
43
egs/libricss/SURT/dprnn_zipformer/encoder_interface.py
Normal file
43
egs/libricss/SURT/dprnn_zipformer/encoder_interface.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class EncoderInterface(nn.Module):
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A tensor of shape (batch_size, input_seq_len, num_features)
|
||||
containing the input features.
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames
|
||||
in `x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing two tensors:
|
||||
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
|
||||
containing unnormalized probabilities, i.e., the output of a
|
||||
linear layer.
|
||||
- encoder_out_lens, a tensor of shape (batch_size,) containing
|
||||
the number of frames in `encoder_out` before padding.
|
||||
"""
|
||||
raise NotImplementedError("Please implement it in a subclass")
|
65
egs/libricss/SURT/dprnn_zipformer/joiner.py
Normal file
65
egs/libricss/SURT/dprnn_zipformer/joiner.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
|
||||
self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
|
||||
self.output_linear = nn.Linear(joiner_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
decoder_out: torch.Tensor,
|
||||
project_input: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
encoder_out:
|
||||
Output from the encoder. Its shape is (N, T, s_range, C).
|
||||
decoder_out:
|
||||
Output from the decoder. Its shape is (N, T, s_range, C).
|
||||
project_input:
|
||||
If true, apply input projections encoder_proj and decoder_proj.
|
||||
If this is false, it is the user's responsibility to do this
|
||||
manually.
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, s_range, C).
|
||||
"""
|
||||
assert encoder_out.ndim == decoder_out.ndim
|
||||
assert encoder_out.ndim in (2, 4)
|
||||
assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
|
||||
|
||||
if project_input:
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||
else:
|
||||
logit = encoder_out + decoder_out
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
|
||||
return logit
|
316
egs/libricss/SURT/dprnn_zipformer/model.py
Normal file
316
egs/libricss/SURT/dprnn_zipformer/model.py
Normal file
@ -0,0 +1,316 @@
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
|
||||
# Copyright 2023 Johns Hopkins University (author: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
||||
class SURT(nn.Module):
|
||||
"""It implements Streaming Unmixing and Recognition Transducer (SURT).
|
||||
https://arxiv.org/abs/2011.13148
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mask_encoder: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
joint_encoder_layer: Optional[nn.Module],
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
num_channels: int,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
mask_encoder:
|
||||
It is the masking network. It generates a mask for each channel of the
|
||||
encoder. These masks are applied to the input features, and then passed
|
||||
to the transcription network.
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
num_channels:
|
||||
It is the number of channels that the input features will be split into.
|
||||
In general, it should be equal to the maximum number of simultaneously
|
||||
active speakers. For most real scenarios, using 2 channels is sufficient.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.mask_encoder = mask_encoder
|
||||
self.encoder = encoder
|
||||
self.joint_encoder_layer = joint_encoder_layer
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
self.num_channels = num_channels
|
||||
|
||||
self.simple_am_proj = nn.Linear(
|
||||
encoder_dim,
|
||||
vocab_size,
|
||||
)
|
||||
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
|
||||
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Linear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
def forward_helper(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
reduction: str = "sum",
|
||||
beam_size: int = 10,
|
||||
use_double_scores: bool = False,
|
||||
subsampling_factor: int = 1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Compute transducer loss for one branch of the SURT model.
|
||||
"""
|
||||
encoder_out, x_lens = self.encoder(x, x_lens)
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
if self.joint_encoder_layer is not None:
|
||||
encoder_out = self.joint_encoder_layer(encoder_out)
|
||||
|
||||
# compute ctc log-probs
|
||||
ctc_output = self.ctc_output(encoder_out)
|
||||
|
||||
# For the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction=reduction,
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction=reduction,
|
||||
)
|
||||
|
||||
# Compute ctc loss
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
torch.arange(len(x_lens), device="cpu"),
|
||||
torch.zeros_like(x_lens, device="cpu"),
|
||||
torch.clone(x_lens).detach().cpu(),
|
||||
),
|
||||
dim=1,
|
||||
).to(torch.int32)
|
||||
# We need to sort supervision_segments in decreasing order of num_frames
|
||||
indices = torch.argsort(supervision_segments[:, 2], descending=True)
|
||||
supervision_segments = supervision_segments[indices]
|
||||
|
||||
# Works with a BPE model
|
||||
decoding_graph = k2.ctc_graph(y, modified=False, device=x.device)
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
ctc_output,
|
||||
supervision_segments,
|
||||
allow_truncate=subsampling_factor - 1,
|
||||
)
|
||||
ctc_loss = k2.ctc_loss(
|
||||
decoding_graph=decoding_graph,
|
||||
dense_fsa_vec=dense_fsa_vec,
|
||||
output_beam=beam_size,
|
||||
reduction="none",
|
||||
use_double_scores=use_double_scores,
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss, ctc_loss)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
reduction: str = "sum",
|
||||
beam_size: int = 10,
|
||||
use_double_scores: bool = False,
|
||||
subsampling_factor: int = 1,
|
||||
return_masks: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor of shape (N*num_channels, S). It contains the labels
|
||||
of the N utterances. The labels are in the range [0, vocab_size). All
|
||||
the channels are concatenated together one after another.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
reduction:
|
||||
"sum" to sum the losses over all utterances in the batch.
|
||||
"none" to return the loss in a 1-D tensor for each utterance
|
||||
in the batch.
|
||||
beam_size:
|
||||
The beam size used in CTC decoding.
|
||||
use_double_scores:
|
||||
If True, use double precision for CTC decoding.
|
||||
subsampling_factor:
|
||||
The subsampling factor of the model. It is used to compute the
|
||||
supervision segments for CTC loss.
|
||||
return_masks:
|
||||
If True, return the masks as well as masked features.
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0), (x.size(), x_lens.size())
|
||||
|
||||
# Apply the mask encoder
|
||||
B, T, F = x.shape
|
||||
processed = self.mask_encoder(x) # B,T,F*num_channels
|
||||
masks = processed.view(B, T, F, self.num_channels).unbind(dim=-1)
|
||||
x_masked = [x * m for m in masks]
|
||||
|
||||
# Recognition
|
||||
# Stack the inputs along the batch axis
|
||||
h = torch.cat(x_masked, dim=0)
|
||||
h_lens = torch.cat([x_lens for _ in range(self.num_channels)], dim=0)
|
||||
|
||||
simple_loss, pruned_loss, ctc_loss = self.forward_helper(
|
||||
h,
|
||||
h_lens,
|
||||
y,
|
||||
prune_range,
|
||||
am_scale,
|
||||
lm_scale,
|
||||
reduction=reduction,
|
||||
beam_size=beam_size,
|
||||
use_double_scores=use_double_scores,
|
||||
subsampling_factor=subsampling_factor,
|
||||
)
|
||||
|
||||
# Chunks the outputs into 2 parts along batch axis and then stack them along a new axis.
|
||||
simple_loss = torch.stack(
|
||||
torch.chunk(simple_loss, self.num_channels, dim=0), dim=0
|
||||
)
|
||||
pruned_loss = torch.stack(
|
||||
torch.chunk(pruned_loss, self.num_channels, dim=0), dim=0
|
||||
)
|
||||
ctc_loss = torch.stack(torch.chunk(ctc_loss, self.num_channels, dim=0), dim=0)
|
||||
|
||||
if return_masks:
|
||||
return (simple_loss, pruned_loss, ctc_loss, x_masked, masks)
|
||||
else:
|
||||
return (simple_loss, pruned_loss, ctc_loss, x_masked)
|
1061
egs/libricss/SURT/dprnn_zipformer/optim.py
Normal file
1061
egs/libricss/SURT/dprnn_zipformer/optim.py
Normal file
File diff suppressed because it is too large
Load Diff
1533
egs/libricss/SURT/dprnn_zipformer/scaling.py
Normal file
1533
egs/libricss/SURT/dprnn_zipformer/scaling.py
Normal file
File diff suppressed because it is too large
Load Diff
114
egs/libricss/SURT/dprnn_zipformer/scaling_converter.py
Normal file
114
egs/libricss/SURT/dprnn_zipformer/scaling_converter.py
Normal file
@ -0,0 +1,114 @@
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file replaces various modules in a model.
|
||||
Specifically, ActivationBalancer is replaced with an identity operator;
|
||||
Whiten is also replaced with an identity operator;
|
||||
BasicNorm is replaced by a module with `exp` removed.
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scaling import ActivationBalancer, BasicNorm, Whiten
|
||||
|
||||
|
||||
class NonScaledNorm(nn.Module):
|
||||
"""See BasicNorm for doc"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels: int,
|
||||
eps_exp: float,
|
||||
channel_dim: int = -1, # CAUTION: see documentation.
|
||||
):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.channel_dim = channel_dim
|
||||
self.eps_exp = eps_exp
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if not torch.jit.is_tracing():
|
||||
assert x.shape[self.channel_dim] == self.num_channels
|
||||
scales = (
|
||||
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
|
||||
).pow(-0.5)
|
||||
return x * scales
|
||||
|
||||
|
||||
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
|
||||
assert isinstance(basic_norm, BasicNorm), type(BasicNorm)
|
||||
norm = NonScaledNorm(
|
||||
num_channels=basic_norm.num_channels,
|
||||
eps_exp=basic_norm.eps.data.exp().item(),
|
||||
channel_dim=basic_norm.channel_dim,
|
||||
)
|
||||
return norm
|
||||
|
||||
|
||||
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
|
||||
# get_submodule was added to nn.Module at v1.9.0
|
||||
def get_submodule(model, target):
|
||||
if target == "":
|
||||
return model
|
||||
atoms: List[str] = target.split(".")
|
||||
mod: torch.nn.Module = model
|
||||
for item in atoms:
|
||||
if not hasattr(mod, item):
|
||||
raise AttributeError(
|
||||
mod._get_name() + " has no " "attribute `" + item + "`"
|
||||
)
|
||||
mod = getattr(mod, item)
|
||||
if not isinstance(mod, torch.nn.Module):
|
||||
raise AttributeError("`" + item + "` is not " "an nn.Module")
|
||||
return mod
|
||||
|
||||
|
||||
def convert_scaled_to_non_scaled(
|
||||
model: nn.Module,
|
||||
inplace: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model:
|
||||
The model to be converted.
|
||||
inplace:
|
||||
If True, the input model is modified inplace.
|
||||
If False, the input model is copied and we modify the copied version.
|
||||
Return:
|
||||
Return a model without scaled layers.
|
||||
"""
|
||||
if not inplace:
|
||||
model = copy.deepcopy(model)
|
||||
|
||||
d = {}
|
||||
for name, m in model.named_modules():
|
||||
if isinstance(m, BasicNorm):
|
||||
d[name] = convert_basic_norm(m)
|
||||
elif isinstance(m, (ActivationBalancer, Whiten)):
|
||||
d[name] = nn.Identity()
|
||||
|
||||
for k, v in d.items():
|
||||
if "." in k:
|
||||
parent, child = k.rsplit(".", maxsplit=1)
|
||||
setattr(get_submodule(model, parent), child, v)
|
||||
else:
|
||||
setattr(model, k, v)
|
||||
|
||||
return model
|
1452
egs/libricss/SURT/dprnn_zipformer/train.py
Executable file
1452
egs/libricss/SURT/dprnn_zipformer/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1343
egs/libricss/SURT/dprnn_zipformer/train_adapt.py
Executable file
1343
egs/libricss/SURT/dprnn_zipformer/train_adapt.py
Executable file
File diff suppressed because it is too large
Load Diff
2891
egs/libricss/SURT/dprnn_zipformer/zipformer.py
Normal file
2891
egs/libricss/SURT/dprnn_zipformer/zipformer.py
Normal file
File diff suppressed because it is too large
Load Diff
85
egs/libricss/SURT/local/add_source_feats.py
Executable file
85
egs/libricss/SURT/local/add_source_feats.py
Executable file
@ -0,0 +1,85 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This file adds source features as temporal arrays to the mixture manifests.
|
||||
It looks for manifests in the directory data/manifests.
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def add_source_feats(num_jobs=1):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
|
||||
for type_affix in ["full", "ov40"]:
|
||||
logging.info(f"Adding source features for {type_affix}")
|
||||
mixed_name_clean = f"train_clean_{type_affix}"
|
||||
mixed_name_rvb = f"train_rvb_{type_affix}"
|
||||
|
||||
logging.info("Reading mixed cuts")
|
||||
mixed_cuts_clean = load_manifest_lazy(
|
||||
src_dir / f"cuts_{mixed_name_clean}.jsonl.gz"
|
||||
)
|
||||
mixed_cuts_rvb = load_manifest_lazy(src_dir / f"cuts_{mixed_name_rvb}.jsonl.gz")
|
||||
|
||||
logging.info("Reading source cuts")
|
||||
source_cuts = load_manifest(src_dir / "librispeech_cuts_train_trimmed.jsonl.gz")
|
||||
|
||||
logging.info("Adding source features to the mixed cuts")
|
||||
with tqdm() as pbar, CutSet.open_writer(
|
||||
src_dir / f"cuts_{mixed_name_clean}_sources.jsonl.gz"
|
||||
) as cut_writer_clean, CutSet.open_writer(
|
||||
src_dir / f"cuts_{mixed_name_rvb}_sources.jsonl.gz"
|
||||
) as cut_writer_rvb, LilcomChunkyWriter(
|
||||
output_dir / f"feats_train_{type_affix}_sources"
|
||||
) as source_feat_writer:
|
||||
for cut_clean, cut_rvb in zip(mixed_cuts_clean, mixed_cuts_rvb):
|
||||
assert cut_rvb.id == cut_clean.id + "_rvb"
|
||||
# Create source_feats and source_feat_offsets
|
||||
# (See `lhotse.datasets.K2SurtDataset` for details)
|
||||
source_feats = []
|
||||
source_feat_offsets = []
|
||||
cur_offset = 0
|
||||
for sup in sorted(
|
||||
cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
|
||||
):
|
||||
source_cut = source_cuts[sup.id]
|
||||
source_feats.append(source_cut.load_features())
|
||||
source_feat_offsets.append(cur_offset)
|
||||
cur_offset += source_cut.num_frames
|
||||
cut_clean.source_feats = source_feat_writer.store_array(
|
||||
cut_clean.id, np.concatenate(source_feats, axis=0)
|
||||
)
|
||||
cut_clean.source_feat_offsets = source_feat_offsets
|
||||
cut_writer_clean.write(cut_clean)
|
||||
cut_rvb.source_feats = cut_clean.source_feats
|
||||
cut_rvb.source_feat_offsets = cut_clean.source_feat_offsets
|
||||
cut_writer_rvb.write(cut_rvb)
|
||||
pbar.update(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
add_source_feats()
|
@ -185,7 +185,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
|
||||
for type in full ov40; do
|
||||
cat <(gunzip -c data/manifests/cuts_train_clean_${type}_sources.jsonl.gz) \
|
||||
<(gunzip -c data/manifests/cuts_train_rvb_${type}_sources.jsonl.gz) |\
|
||||
shuf | gzip -c > data/manifests/cuts_train_${type}_sources.jsonl.gz
|
||||
shuf | gzip -c > data/manifests/cuts_train_comb_${type}_sources.jsonl.gz
|
||||
done
|
||||
fi
|
||||
|
||||
|
@ -887,8 +887,7 @@ def write_error_stats_with_timestamps(
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
|
||||
return tot_err_rate, mean_delay, var_delay
|
||||
return float(tot_err_rate), float(mean_delay), float(var_delay)
|
||||
|
||||
|
||||
def write_surt_error_stats(
|
||||
@ -1282,10 +1281,10 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
assert lengths.ndim == 1, lengths.ndim
|
||||
max_len = max(max_len, lengths.max())
|
||||
n = lengths.size(0)
|
||||
seq_range = torch.arange(0, max_len, device=lengths.device)
|
||||
expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
|
||||
|
||||
return expaned_lengths >= lengths.unsqueeze(-1)
|
||||
expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
|
||||
|
||||
return expaned_lengths >= lengths.unsqueeze(1)
|
||||
|
||||
|
||||
# Copied and modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py
|
||||
@ -1648,7 +1647,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
|
||||
List of timestamp of each word.
|
||||
"""
|
||||
start_token = b"\xe2\x96\x81".decode() # '_'
|
||||
assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
|
||||
assert len(tokens) == len(timestamp)
|
||||
ans = []
|
||||
for i in range(len(tokens)):
|
||||
flag = False
|
||||
|
Loading…
x
Reference in New Issue
Block a user