Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Merge 136c03d040ad8d81ae8d0ccf5a3a5d9d11d9c79c into cbf8c18ebd274dfeea9b8aa224ff5faad713c28c
This commit is contained in commit 711c1ccbcd.
@@ -19,7 +19,9 @@ import argparse
 import logging
 from functools import lru_cache
 from pathlib import Path
+from typing import Callable, List, Optional

+import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
@@ -179,7 +181,27 @@ class LibriSpeechAsrDataModule:
             "with training dataset. ",
         )

-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+    def train_dataloaders(
+        self,
+        cuts_train: CutSet,
+        extra_input_transforms: Optional[
+            List[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]
+        ],
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            The cutset for training.
+          extra_input_transforms:
+            The extra input transforms that will be applied after all input
+            transforms, e.g., after SpecAugment if there is any.
+            Each input transform accepts two input arguments:
+              - A 3-D torch.Tensor of shape (N, T, C)
+              - A 2-D torch.Tensor of shape (num_seqs, 3), where the
+                first column is `sequence_idx`, the second column is
+                `start_frame`, and the third column is `num_frames`.
+            and returns a 3-D torch.Tensor of shape (N, T, C).
+        """
         logging.info("About to get Musan cuts")
         cuts_musan = load_manifest(
             self.args.manifest_dir / "cuts_musan.json.gz"

@@ -228,6 +250,10 @@ class LibriSpeechAsrDataModule:
         else:
             logging.info("Disable SpecAugment")

+        if extra_input_transforms is not None:
+            input_transforms += extra_input_transforms
+        logging.info(f"Input transforms: {input_transforms}")
+
         logging.info("About to create train dataset")
         train = K2SpeechRecognitionDataset(
             cut_transforms=transforms,
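For reference, an extra input transform only needs to match the callable signature documented above. A minimal sketch (the `add_gaussian_noise` helper below is made up for illustration and is not part of this change):

import torch

def add_gaussian_noise(
    features: torch.Tensor,  # 3-D tensor of shape (N, T, C)
    supervision_segments: torch.Tensor,  # 2-D tensor of shape (num_seqs, 3)
) -> torch.Tensor:
    # Example transform: perturb the features with small Gaussian noise.
    return features + 0.01 * torch.randn_like(features)

# It would then be passed to the data module like:
#   train_dl = librispeech.train_dataloaders(
#       train_cuts, extra_input_transforms=[add_gaussian_noise]
#   )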
@@ -629,9 +629,11 @@ class RelPositionMultiheadAttention(nn.Module):

         if torch.equal(query, key) and torch.equal(key, value):
             # self-attention
-            q, k, v = nn.functional.linear(
-                query, in_proj_weight, in_proj_bias
-            ).chunk(3, dim=-1)
+            q, k, v = (
+                nn.functional.linear(query, in_proj_weight, in_proj_bias)
+                .relu()
+                .chunk(3, dim=-1)
+            )

         elif torch.equal(key, value):
             # encoder-decoder attention

@@ -642,7 +644,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = nn.functional.linear(query, _w, _b).relu()
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim

@@ -650,7 +652,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
+            k, v = nn.functional.linear(key, _w, _b).relu().chunk(2, dim=-1)

         else:
             # This is inline in_proj function with in_proj_weight and in_proj_bias

@@ -660,7 +662,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            q = nn.functional.linear(query, _w, _b)
+            q = nn.functional.linear(query, _w, _b).relu()

             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias

@@ -669,7 +671,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:_end, :]
             if _b is not None:
                 _b = _b[_start:_end]
-            k = nn.functional.linear(key, _w, _b)
+            k = nn.functional.linear(key, _w, _b).relu()

             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias

@@ -678,7 +680,7 @@ class RelPositionMultiheadAttention(nn.Module):
             _w = in_proj_weight[_start:, :]
             if _b is not None:
                 _b = _b[_start:]
-            v = nn.functional.linear(value, _w, _b)
+            v = nn.functional.linear(value, _w, _b).relu()

         if attn_mask is not None:
             assert (
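The hunks above insert a ReLU between the input projection and the attention computation for q, k and v. A small standalone sketch of the modified self-attention projection (tensor sizes are assumed for illustration):

import torch
import torch.nn as nn

embed_dim = 512  # assumed size for illustration
x = torch.randn(10, 2, embed_dim)  # (seq_len, batch, embed_dim)
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.randn(3 * embed_dim)

# Self-attention path: one fused projection, then ReLU, then split into q, k, v.
q, k, v = (
    nn.functional.linear(x, in_proj_weight, in_proj_bias)
    .relu()
    .chunk(3, dim=-1)
)
assert q.shape == k.shape == v.shape == x.shape
assert (q >= 0).all() and (k >= 0).all() and (v >= 0).all()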
@@ -441,7 +441,9 @@ def main():
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )

    model.to(device)
    model.eval()
egs/librispeech/ASR/transducer_stateless/frame_shift.py  (new file, 84 lines)
@@ -0,0 +1,84 @@
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from lhotse.utils import LOG_EPSILON
+
+
+def apply_frame_shift(
+    features: torch.Tensor,
+    supervision_segments: torch.Tensor,
+) -> torch.Tensor:
+    """Apply random frame shift along the time axis.
+
+    For instance, for the input frame `[a, b, c, d]`,
+
+    - If frame shift is 0, the resulting output is `[a, b, c, d]`
+    - If frame shift is -1, the resulting output is `[b, c, d, a]`
+    - If frame shift is 1, the resulting output is `[d, a, b, c]`
+    - If frame shift is 2, the resulting output is `[c, d, a, b]`
+
+    Args:
+      features:
+        A 3-D tensor of shape (N, T, C).
+      supervision_segments:
+        A 2-D tensor of shape (num_seqs, 3). The first column is
+        `sequence_idx`, the second column is `start_frame`, and
+        the third column is `num_frames`.
+    Returns:
+      Return a 3-D tensor of shape (N, T, C).
+    """
+    # We assume the subsampling_factor is 4. If you change the
+    # subsampling_factor, you should also change the following
+    # list accordingly
+    #
+    # The value in frame_shifts is selected in such a way that
+    # "value % subsampling_factor" is not duplicated in frame_shifts.
+    frame_shifts = [-1, 0, 1, 2]
+
+    N = features.size(0)
+
+    # We don't support cut concatenation here
+    assert torch.all(
+        torch.eq(supervision_segments[:, 0], torch.arange(N))
+    ), supervision_segments
+
+    ans = []
+    for i in range(N):
+        start = supervision_segments[i, 1]
+        end = start + supervision_segments[i, 2]
+
+        feat = features[i, start:end, :]
+
+        r = torch.randint(low=0, high=len(frame_shifts), size=(1,)).item()
+        frame_shift = frame_shifts[r]
+
+        # You can enable the following debug statement
+        # and run ./transducer_stateless/test_frame_shift.py to
+        # view the debug output.
+        # print("frame_shift", frame_shift)
+
+        feat = torch.roll(feat, shifts=frame_shift, dims=0)
+        ans.append(feat)
+
+    ans = torch.nn.utils.rnn.pad_sequence(
+        ans,
+        batch_first=True,
+        padding_value=LOG_EPSILON,
+    )
+    assert features.shape == ans.shape
+
+    return ans
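As a quick sanity check of the circular-shift semantics documented above, torch.roll reproduces the listed outputs (a standalone sketch, not part of the new file):

import torch

x = torch.tensor([1, 2, 3, 4])           # stands for [a, b, c, d]
print(torch.roll(x, shifts=0, dims=0))   # tensor([1, 2, 3, 4])  -> [a, b, c, d]
print(torch.roll(x, shifts=-1, dims=0))  # tensor([2, 3, 4, 1])  -> [b, c, d, a]
print(torch.roll(x, shifts=1, dims=0))   # tensor([4, 1, 2, 3])  -> [d, a, b, c]
print(torch.roll(x, shifts=2, dims=0))   # tensor([3, 4, 1, 2])  -> [c, d, a, b]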
@@ -79,7 +79,10 @@ class Transducer(nn.Module):
          modified_transducer_prob:
            The probability to use modified transducer loss.
        Returns:
-         Return the transducer loss.
+         Return a tuple containing:
+           - the transducer loss, a tensor containing only one entry
+           - encoder_out, a tensor of shape (N, num_frames, encoder_out_dim)
+           - encoder_out_lens, a tensor of shape (N,)
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape

@@ -140,4 +143,8 @@ class Transducer(nn.Module):
                from_log_softmax=False,
            )

-        return loss
+        return (
+            loss,
+            encoder_out,
+            x_lens,
+        )
egs/librispeech/ASR/transducer_stateless/test_frame_shift.py  (new executable file, 70 lines)
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+To run this file, do:
+
+    cd icefall/egs/librispeech/ASR
+    python ./transducer_stateless/test_frame_shift.py
+"""
+
+import torch
+from frame_shift import apply_frame_shift
+
+
+def test_apply_frame_shift():
+    features = torch.tensor(
+        [
+            [
+                [1, 2, 5],
+                [2, 6, 9],
+                [3, 0, 2],
+                [4, 11, 13],
+                [0, 0, 0],
+                [0, 0, 0],
+            ],
+            [
+                [1, 3, 9],
+                [2, 5, 8],
+                [3, 3, 6],
+                [4, 0, 3],
+                [5, 1, 2],
+                [6, 6, 6],
+            ],
+        ]
+    )
+    supervision_segments = torch.tensor(
+        [
+            [0, 0, 4],
+            [1, 0, 6],
+        ],
+        dtype=torch.int32,
+    )
+    shifted_features = apply_frame_shift(features, supervision_segments)
+
+    # You can enable the debug statement in frame_shift.py
+    # and check the resulting shifted_features. I've verified
+    # manually that it is correct.
+    print(shifted_features)
+
+
+def main():
+    test_apply_frame_shift()
+
+
+if __name__ == "__main__":
+    main()
@@ -46,6 +46,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
+from frame_shift import apply_frame_shift
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.utils import fix_random_seed
@@ -149,6 +150,21 @@ def get_parser():
        """,
    )

+    parser.add_argument(
+        "--apply-frame-shift",
+        type=str2bool,
+        default=False,
+        help="If enabled, apply random frame shift along the time axis",
+    )
+
+    parser.add_argument(
+        "--ctc-weight",
+        type=float,
+        default=0.25,
+        help="""If not zero, the total loss is:
+          (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
+        """,
+    )
    return parser

@@ -217,6 +233,13 @@ def get_params() -> AttributeDict:
            "vgg_frontend": False,
            # parameters for Noam
            "warm_step": 80000,  # For the 100h subset, use 8k
+           #
+           # parameters for ctc_loss, used only when ctc_weight > 0
+           "modified_ctc_topo": False,
+           "beam_size": 10,
+           "reduction": "sum",
+           "use_double_scores": True,
+           #
            "env_info": get_env_info(),
        }
    )
@@ -270,6 +293,17 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
    return model


+def get_ctc_model(params: AttributeDict):
+    if params.ctc_weight > 0:
+        return nn.Sequential(
+            nn.Dropout(p=0.1),
+            nn.Linear(params.encoder_out_dim, params.vocab_size),
+            nn.LogSoftmax(dim=-1),
+        )
+    else:
+        return None
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
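A quick shape check of the CTC head defined above (a sketch; the dimensions below are assumed for illustration, while the recipe takes them from params.encoder_out_dim and params.vocab_size):

import torch
import torch.nn as nn

encoder_out_dim = 512  # assumed value
vocab_size = 500       # assumed value

ctc = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Linear(encoder_out_dim, vocab_size),
    nn.LogSoftmax(dim=-1),
)

encoder_out = torch.randn(8, 100, encoder_out_dim)  # (N, T, encoder_out_dim)
log_probs = ctc(encoder_out)                        # (N, T, vocab_size)
assert log_probs.shape == (8, 100, vocab_size)
# Each frame is a normalized log-distribution over tokens:
assert torch.allclose(log_probs.exp().sum(dim=-1), torch.ones(8, 100), atol=1e-4)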
@@ -390,16 +424,55 @@ def compute_loss(
    feature_lens = supervisions["num_frames"].to(device)

    texts = batch["supervisions"]["text"]
-   y = sp.encode(texts, out_type=int)
-   y = k2.RaggedTensor(y).to(device)
+   token_ids = sp.encode(texts, out_type=int)
+   y = k2.RaggedTensor(token_ids).to(device)

    with torch.set_grad_enabled(is_training):
-       loss = model(
+       transducer_loss, encoder_out, encoder_out_lens = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
            modified_transducer_prob=params.modified_transducer_prob,
        )
+       loss = transducer_loss
+
+       if params.ctc_weight > 0:
+           ctc_model = (
+               model.module.ctc if hasattr(model, "module") else model.ctc
+           )
+           ctc_graph = k2.ctc_graph(
+               token_ids, modified=params.modified_ctc_topo, device=device
+           )
+           # Note: We assume `encoder_out_lens` is sorted in descending order.
+           # If not, it will throw in k2.ctc_loss().
+           supervision_segments = torch.stack(
+               [
+                   torch.arange(encoder_out.size(0), dtype=torch.int32),
+                   torch.zeros(encoder_out.size(0), dtype=torch.int32),
+                   encoder_out_lens.cpu(),
+               ],
+               dim=1,
+           ).to(torch.int32)
+           nnet_out = ctc_model(encoder_out)
+
+           dense_fsa_vec = k2.DenseFsaVec(
+               nnet_out,
+               supervision_segments,
+               allow_truncate=0,
+           )
+
+           # Note: transducer_loss should use the same reduction as ctc_loss
+           ctc_loss = k2.ctc_loss(
+               decoding_graph=ctc_graph,
+               dense_fsa_vec=dense_fsa_vec,
+               output_beam=params.beam_size,
+               reduction=params.reduction,
+               use_double_scores=params.use_double_scores,
+           )
+           assert ctc_loss.requires_grad == is_training
+           loss = (
+               1 - params.ctc_weight
+           ) * transducer_loss + params.ctc_weight * ctc_loss
+
    assert loss.requires_grad == is_training

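A small numeric sketch of the supervision_segments layout and the weighted loss combination used above (all values are made up for illustration):

import torch

ctc_weight = 0.25
encoder_out_lens = torch.tensor([95, 80, 60])  # must already be sorted in descending order

# One row per utterance: (sequence_idx, start_frame, num_frames)
supervision_segments = torch.stack(
    [
        torch.arange(encoder_out_lens.size(0), dtype=torch.int32),
        torch.zeros(encoder_out_lens.size(0), dtype=torch.int32),
        encoder_out_lens.to(torch.int32),
    ],
    dim=1,
)
# -> [[0, 0, 95], [1, 0, 80], [2, 0, 60]]

transducer_loss = torch.tensor(120.0)  # made-up value, reduction="sum"
ctc_loss = torch.tensor(200.0)         # made-up value
loss = (1 - ctc_weight) * transducer_loss + ctc_weight * ctc_loss
assert loss.item() == 140.0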
@@ -408,6 +481,9 @@ def compute_loss(

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
+   info["transducer_loss"] = transducer_loss.detach().cpu().item()
+   if params.ctc_weight > 0:
+       info["ctc_loss"] = ctc_loss.detach().cpu().item()

    return loss, info

@@ -590,6 +666,11 @@ def run(rank, world_size, args):

    logging.info("About to create model")
    model = get_transducer_model(params)
+   model.ctc = get_ctc_model(params)
+   if model.ctc is not None:
+       logging.info(f"Enable CTC loss with weight: {params.ctc_weight}")
+   else:
+       logging.info("Disable CTC loss")

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
@@ -636,7 +717,17 @@ def run(rank, world_size, args):
    logging.info(f"After removing short and long utterances: {num_left}")
    logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)")

-   train_dl = librispeech.train_dataloaders(train_cuts)
+   if params.apply_frame_shift:
+       logging.info("Enable random frame shift")
+       extra_input_transforms = [apply_frame_shift]
+   else:
+       logging.info("Disable random frame shift")
+       extra_input_transforms = None
+
+   train_dl = librispeech.train_dataloaders(
+       train_cuts,
+       extra_input_transforms=extra_input_transforms,
+   )

    valid_cuts = librispeech.dev_clean_cuts()
    valid_cuts += librispeech.dev_other_cuts()