Merge branch 'k2-fsa:master' into tiny

Tiance Wang 2023-03-25 10:41:35 +08:00 committed by GitHub
commit f021f53fe0
6 changed files with 25 additions and 172 deletions

View File

@@ -282,7 +282,6 @@ class ScaledAdam(BatchedOptimizer):
         batch_size = p.shape[0]
         numel = p.numel() // batch_size
-        numel = p.numel()

         if numel > 1:
             # "param_rms" just periodically records the scalar root-mean-square value of
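
The deleted line matters because ScaledAdam (via BatchedOptimizer) stacks parameters of identical shape into one batched tensor, so the per-parameter size has to be p.numel() // batch_size; overwriting it with p.numel() made stacked scalar parameters look multi-element. A minimal sketch of the distinction, using a made-up stacked tensor rather than real optimizer state:

import torch

# Hypothetical stand-in for four scalar parameters stacked along dim 0,
# the way BatchedOptimizer groups parameters of identical shape.
p = torch.zeros(4)
batch_size = p.shape[0]

per_param_numel = p.numel() // batch_size  # 1: each underlying parameter is a scalar
whole_numel = p.numel()                    # 4: size of the whole stacked tensor

# With the removed line, `numel` became 4, so scalar parameters would
# wrongly take the `numel > 1` branch below.
print(per_param_numel, whole_numel)  # -> 1 4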

View File

@@ -781,13 +781,12 @@ class AttentionDownsample(torch.nn.Module):
         ds = self.downsample
         d_seq_len = (seq_len + ds - 1) // ds

-        # Pad to an exact multiple of self.downsample
-        if seq_len != d_seq_len * ds:
-            # right-pad src, repeating the last element.
-            pad = d_seq_len * ds - seq_len
-            src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
-            src = torch.cat((src, src_extra), dim=0)
-            assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds)
+        # Pad to an exact multiple of self.downsample, could be 0 for onnx-export-compatibility
+        # right-pad src, repeating the last element.
+        pad = d_seq_len * ds - seq_len
+        src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+        src = torch.cat((src, src_extra), dim=0)
+        assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds)

         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
         scores = (src * self.query).sum(dim=-1, keepdim=True)
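
Removing the `if seq_len != d_seq_len * ds:` guard means the padding is applied unconditionally, with pad possibly 0, so the forward pass has no data-dependent branch; per the updated comment, that is what keeps ONNX export happy. A small standalone sketch with made-up sizes, showing that the unconditional path is a no-op when nothing needs padding:

import torch

seq_len, batch_size, in_channels, ds = 8, 2, 4, 4  # made-up sizes; seq_len is already a multiple of ds
src = torch.randn(seq_len, batch_size, in_channels)

d_seq_len = (seq_len + ds - 1) // ds
pad = d_seq_len * ds - seq_len  # 0 here

# expand(0, ...) produces an empty tensor, so concatenating it leaves src
# unchanged; no branch is needed and the exported graph stays static.
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds)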

View File

@@ -20,7 +20,7 @@ import logging
 from pathlib import Path

 import torch
-from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer
+from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter

 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -69,7 +69,7 @@ def compute_fbank_wenetspeech_dev_test():
         storage_path=f"{in_out_dir}/feats_{partition}",
         num_workers=num_workers,
         batch_duration=batch_duration,
-        storage_type=LilcomHdf5Writer,
+        storage_type=LilcomChunkyWriter,
         overwrite=True,
     )
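
LilcomChunkyWriter only changes the on-disk storage backend (a chunked lilcom archive instead of HDF5); the feature-extraction call itself is untouched. A rough sketch of that call with hypothetical paths, since the real script derives them from in_out_dir and the partition name:

import torch
from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter

torch.set_num_threads(1)  # per the script's comment on torch multithreading

# Hypothetical manifest; the real script loops over the dev/test partitions.
cut_set = CutSet.from_file("cuts_DEV_raw.jsonl.gz")
extractor = KaldifeatFbank(KaldifeatFbankConfig(device="cpu"))

cut_set = cut_set.compute_and_store_features_batch(
    extractor=extractor,
    storage_path="feats_DEV",         # stands in for f"{in_out_dir}/feats_{partition}"
    num_workers=4,
    batch_duration=600.0,
    storage_type=LilcomChunkyWriter,  # was LilcomHdf5Writer before this change
    overwrite=True,
)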

View File

@@ -46,9 +46,6 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool

-set_caching_enabled(False)
-torch.set_num_threads(1)
-

 class _SeedWorkers:
     def __init__(self, seed: int):
@@ -348,24 +345,18 @@ class WenetSpeechAsrDataModule:
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )

         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")

-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-        dev_iter_dataset = IterableDatasetWrapper(
-            dataset=validate,
-            sampler=valid_sampler,
-        )
         valid_dl = DataLoader(
-            dev_iter_dataset,
+            validate,
             batch_size=None,
+            sampler=valid_sampler,
             num_workers=self.args.num_workers,
             persistent_workers=False,
         )
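
The wrapper can go away because the dataset built just above is a map-style lhotse dataset (a K2SpeechRecognitionDataset in this recipe, as far as I can tell) whose items are whole batches of cuts, so the lhotse sampler can be handed straight to DataLoader. A simplified sketch of the new wiring, with the CutSet and the numeric settings assumed rather than read from the recipe's command-line options:

from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset

cuts_valid = CutSet.from_file("cuts_DEV.jsonl.gz")  # hypothetical path
validate = K2SpeechRecognitionDataset(return_cuts=True)

valid_sampler = DynamicBucketingSampler(
    cuts_valid,
    max_duration=200.0,  # stand-in for self.args.max_duration
    shuffle=False,
)

# The sampler yields whole batches of cuts, hence batch_size=None; the
# map-style dataset is passed directly instead of an IterableDatasetWrapper.
valid_dl = DataLoader(
    validate,
    batch_size=None,
    sampler=valid_sampler,
    num_workers=2,
    persistent_workers=False,
)
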
@@ -383,19 +374,13 @@ class WenetSpeechAsrDataModule:
         sampler = DynamicBucketingSampler(
             cuts,
             max_duration=self.args.max_duration,
-            rank=0,
-            world_size=1,
             shuffle=False,
         )

-        from lhotse.dataset.iterable_dataset import IterableDatasetWrapper
-        test_iter_dataset = IterableDatasetWrapper(
-            dataset=test,
-            sampler=sampler,
-        )
         test_dl = DataLoader(
-            test_iter_dataset,
+            test,
             batch_size=None,
+            sampler=sampler,
             num_workers=self.args.num_workers,
         )
         return test_dl
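
Dropping rank=0 and world_size=1 leans on the sampler defaults: my understanding of lhotse is that, when these are left unset, the sampler works out the sharding from the torch.distributed environment and falls back to a single process when none is initialized. A minimal sketch under that assumption:

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

cuts = CutSet.from_file("cuts_TEST_NET.jsonl.gz")  # hypothetical path

# rank/world_size are omitted, so the sampler picks them up from the
# distributed context (single process here) instead of being hard-coded.
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=200.0,  # stand-in for self.args.max_duration
    shuffle=False,
)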

View File

@@ -651,83 +651,18 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)

-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
-
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
+
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]
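
Downstream, main() now feeds these dataloaders straight into the decoding loop. A rough sketch of how they are consumed; decode_dataset and save_results are hypothetical names for the per-set decoding and result-saving helpers defined elsewhere in decode.py, and their real signatures may differ:

# Assumes dev_dl, test_net_dl, test_meeting_dl were built as in the diff above,
# and that params/model are set up earlier in main().
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dls = [dev_dl, test_net_dl, test_meeting_dl]

for test_set, dl in zip(test_sets, test_dls):
    results_dict = decode_dataset(dl=dl, params=params, model=model)
    save_results(params=params, test_set_name=test_set, results_dict=results_dict)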

View File

@@ -661,83 +661,18 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    # Note: Please use "pip install webdataset==0.1.103"
-    # for installing the webdataset.
-    import glob
-    import os
-
-    from lhotse import CutSet
-    from lhotse.dataset.webdataset import export_to_webdataset
-
     # we need cut ids to display recognition results.
     args.return_cuts = True
     wenetspeech = WenetSpeechAsrDataModule(args)

-    dev = "dev"
-    test_net = "test_net"
-    test_meeting = "test_meeting"
-
-    if not os.path.exists(f"{dev}/shared-0.tar"):
-        os.makedirs(dev)
-        dev_cuts = wenetspeech.valid_cuts()
-        export_to_webdataset(
-            dev_cuts,
-            output_path=f"{dev}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        os.makedirs(test_net)
-        test_net_cuts = wenetspeech.test_net_cuts()
-        export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meeting}/shared-0.tar"):
-        os.makedirs(test_meeting)
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meeting}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    dev_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar")))
-    ]
-    cuts_dev_webdataset = CutSet.from_webdataset(
-        dev_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_net_shards = [
-        str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
-    ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    test_meeting_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar")))
-    ]
-    cuts_test_meeting_webdataset = CutSet.from_webdataset(
-        test_meeting_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
-
-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meeting_webdataset)
+    dev_cuts = wenetspeech.valid_cuts()
+    dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
+
+    test_net_cuts = wenetspeech.test_net_cuts()
+    test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+
+    test_meeting_cuts = wenetspeech.test_meeting_cuts()
+    test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)

     test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
     test_dl = [dev_dl, test_net_dl, test_meeting_dl]