mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-04 14:44:18 +00:00

update codes

This commit is contained in:
parent 4a725a5eec
commit 396aaefbaa
@@ -90,6 +90,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state

         word = word2id[word]
+        print(pieces, token2id)
         pieces = [token2id[i] for i in pieces]

         for i in range(len(pieces) - 1):
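The added print traces how each word's pieces are mapped through token2id before arcs are created. A toy sketch of that mapping, with a hypothetical token2id (not the one the recipe generates):

token2id = {"<blk>": 0, "▁HE": 1, "LL": 2, "O": 3}  # hypothetical mapping
pieces = ["▁HE", "LL", "O"]
pieces = [token2id[i] for i in pieces]
print(pieces, token2id)  # [1, 2, 3] {'<blk>': 0, '▁HE': 1, 'LL': 2, 'O': 3}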
@@ -16,57 +16,35 @@

 import torch
 import torch.nn as nn
 import torch.nn.functional as F


 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
         super().__init__()

-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.output_linear = nn.Linear(input_dim, output_dim)
+        self.inner_linear = nn.Linear(input_dim, inner_dim)
+        self.output_linear = nn.Linear(inner_dim, output_dim)

     def forward(
-        self,
-        encoder_out: torch.Tensor,
-        decoder_out: torch.Tensor,
-        encoder_out_len: torch.Tensor,
-        decoder_out_len: torch.Tensor,
+        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, self.input_dim).
+            Output from the encoder. Its shape is (N, T, s_range, C).
           decoder_out:
-            Output from the decoder. Its shape is (N, U, self.input_dim).
+            Output from the decoder. Its shape is (N, T, s_range, C).
         Returns:
-          Return a tensor of shape (sum_all_TU, self.output_dim).
+          Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 3
-        assert encoder_out.size(0) == decoder_out.size(0)
-        assert encoder_out.size(2) == self.input_dim
-        assert decoder_out.size(2) == self.input_dim
+        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.shape == decoder_out.shape

-        N = encoder_out.size(0)
+        logit = encoder_out + decoder_out

-        encoder_out_list = [
-            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
-        ]
+        logit = self.inner_linear(torch.tanh(logit))

-        decoder_out_list = [
-            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
-        ]
+        output = self.output_linear(F.relu(logit))

-        x = [
-            e.unsqueeze(1) + d.unsqueeze(0)
-            for e, d in zip(encoder_out_list, decoder_out_list)
-        ]
-
-        x = [p.reshape(-1, self.input_dim) for p in x]
-        x = torch.cat(x)
-
-        activations = torch.tanh(x)
-
-        logits = self.output_linear(activations)
-
-        return logits
+        return output
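A quick shape check for the rewritten Joiner; the sizes below (N=2, T=10, s_range=5, dims 512/500) are illustrative assumptions, not values from this commit:

import torch

joiner = Joiner(input_dim=512, inner_dim=512, output_dim=500)
# 4-D pruned inputs, e.g. as produced by k2.do_rnnt_pruning
am_pruned = torch.randn(2, 10, 5, 512)  # (N, T, s_range, C)
lm_pruned = torch.randn(2, 10, 5, 512)  # (N, T, s_range, C)
logits = joiner(am_pruned, lm_pruned)
assert logits.shape == (2, 10, 5, 500)  # (N, T, s_range, output_dim)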
@@ -1,4 +1,4 @@
-# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang)
+# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang, Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import random

 import k2
 import torch
@@ -64,7 +63,9 @@ class Transducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
-        modified_transducer_prob: float = 0.0,
+        prune_range: int = 5,
+        am_scale: float = 0.0,
+        lm_scale: float = 0.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -76,10 +77,23 @@ class Transducer(nn.Module):
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
-          modified_transducer_prob:
-            The probability to use modified transducer loss.
+          prune_range:
+            The prune range for the rnnt loss, i.e., how many symbols (context)
+            we consider for each frame when computing the loss.
+          am_scale:
+            The scale to smooth the loss with the am (output of the encoder
+            network) part.
+          lm_scale:
+            The scale to smooth the loss with the lm (output of the predictor
+            network) part.
         Returns:
           Return the transducer loss.
+
+        Note:
+          Regarding am_scale & lm_scale, they make the loss function take the
+          form:
+            lm_scale * lm_probs + am_scale * am_probs
+              + (1 - lm_scale - am_scale) * combined_probs
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
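A hedged sketch of calling the updated forward(); model, x, x_lens, and y are assumed to be prepared elsewhere, and the final combination of the two returned losses is a choice left to the training loop, not fixed by this diff:

simple_loss, pruned_loss = model(
    x=x,            # (N, T, C) acoustic features
    x_lens=x_lens,  # (N,) valid frame counts
    y=y,            # k2.RaggedTensor of labels
    prune_range=5,
    am_scale=0.0,   # illustrative smoothing scales
    lm_scale=0.25,
)
loss = simple_loss + pruned_loss  # one plausible weighting; an assumption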
@@ -97,47 +111,59 @@ class Transducer(nn.Module):
         blank_id = self.decoder.blank_id
         sos_y = add_sos(y, sos_id=blank_id)

         # sos_y_padded: [B, S + 1], start with SOS.
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)

         # decoder_out: [B, S + 1, C]
         decoder_out = self.decoder(sos_y_padded)

-        # +1 here since a blank is prepended to each utterance.
-        logits = self.joiner(
-            encoder_out=encoder_out,
-            decoder_out=decoder_out,
-            encoder_out_len=x_lens,
-            decoder_out_len=y_lens + 1,
-        )
-
         # rnnt_loss requires 0 padded targets
         # Note: y does not start with SOS
         # y_padded : [B, S]
         y_padded = y.pad(mode="constant", padding_value=0)

-        # We don't put this `import` at the beginning of the file
-        # as it is required only during training, not during the
-        # inference stage
-        import optimized_transducer
+        y_padded = y_padded.to(torch.int64)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens

-        assert 0 <= modified_transducer_prob <= 1
-
-        if modified_transducer_prob == 0:
-            one_sym_per_frame = False
-        elif random.random() < modified_transducer_prob:
-            # random.random() returns a float in the range [0, 1)
-            one_sym_per_frame = True
-        else:
-            one_sym_per_frame = False
-
-        loss = optimized_transducer.transducer_loss(
-            logits=logits,
-            targets=y_padded,
-            logit_lengths=x_lens,
-            target_lengths=y_lens,
-            blank=blank_id,
+        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+            lm=decoder_out,
+            am=encoder_out,
+            symbols=y_padded,
+            termination_symbol=blank_id,
+            lm_only_scale=lm_scale,
+            am_only_scale=am_scale,
+            boundary=boundary,
             reduction="sum",
-            one_sym_per_frame=one_sym_per_frame,
-            from_log_softmax=False,
+            return_grad=True,
         )

-        return loss
+        # ranges : [B, T, prune_range]
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad=px_grad,
+            py_grad=py_grad,
+            boundary=boundary,
+            s_range=prune_range,
+        )
+
+        # am_pruned : [B, T, prune_range, C]
+        # lm_pruned : [B, T, prune_range, C]
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
+            am=encoder_out, lm=decoder_out, ranges=ranges
+        )
+
+        # logits : [B, T, prune_range, C]
+        logits = self.joiner(am_pruned, lm_pruned)
+
+        pruned_loss = k2.rnnt_loss_pruned(
+            logits=logits,
+            symbols=y_padded,
+            ranges=ranges,
+            termination_symbol=blank_id,
+            boundary=boundary,
+            reduction="sum",
+        )
+
+        return (simple_loss, pruned_loss)
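The new pipeline is two-stage: k2.rnnt_loss_smoothed over the full lattice returns the gradients (px_grad, py_grad) that k2.get_rnnt_prune_ranges converts into per-frame symbol windows, and only those windows are fed through the joiner for k2.rnnt_loss_pruned. A toy sketch of the boundary tensor both stages share, with assumed lengths; per k2's rnnt-loss convention each row is [begin_symbol, begin_frame, end_symbol, end_frame]:

import torch

x_lens = torch.tensor([100, 80])  # frames per utterance (toy values)
y_lens = torch.tensor([20, 15])   # symbols per utterance (toy values)
boundary = torch.zeros((2, 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
print(boundary)
# tensor([[  0,   0,  20, 100],
#         [  0,   0,  15,  80]])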
@@ -71,6 +71,8 @@ def compute_fbank_tedlium():
             recordings=m["recordings"],
             supervisions=m["supervisions"],
         )
+        # Split long cuts into many short and un-overlapping cuts
+        cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
         if "train" in partition:
             cut_set = (
                 cut_set
@@ -85,8 +87,6 @@ def compute_fbank_tedlium():
                 executor=ex,
                 storage_type=ChunkedLilcomHdf5Writer,
             )
-        # Split long cuts into many short and un-overlapping cuts
-        cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
         cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
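These two hunks move trim_to_supervisions from after feature extraction to right after the CutSet is built, so features are computed on the already-trimmed, non-overlapping cuts. A minimal lhotse sketch of the new order (the manifest paths are hypothetical):

from lhotse import CutSet, RecordingSet, SupervisionSet

# Hypothetical manifests for one partition.
recordings = RecordingSet.from_file("recordings_train.jsonl.gz")
supervisions = SupervisionSet.from_file("supervisions_train.jsonl.gz")

cut_set = CutSet.from_manifests(recordings=recordings, supervisions=supervisions)
# Split long cuts into short, non-overlapping, per-supervision cuts
# before any feature computation.
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)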
@@ -29,9 +29,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  5000
-  2000
-  1000
+  #5000
+  #2000
+  #1000
   500
 )
@@ -71,7 +71,7 @@ class TedLiumAsrDataModule:
         group.add_argument(
             "--manifest-dir",
             type=Path,
-            default=Path("data/fbank"),
+            default=Path("data/fbank_overlap_false"),
             help="Path to directory with train/valid/test cuts.",
         )
         group.add_argument(
@@ -348,7 +348,6 @@ class TedLiumAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        print(self.args.manifest_dir)
         return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")

     @lru_cache()
@@ -77,7 +77,7 @@ def get_parser():
     parser.add_argument(
         "--master-port",
         type=int,
-        default=12354,
+        default=12350,
         help="Master port to use for DDP training.",
     )

@@ -108,7 +108,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless/exp",
+        default="pruned_transducer_stateless/exp-4-gpus-300",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -658,7 +658,7 @@ def run(rank, world_size, args):
         # Keep only utterances with duration between 1 second and max seconds
-        # Here, we set max as 20.0.
-        return 1.0 <= c.duration <= 20.0
+        # If you want to use a big max-duration, you can set it as 17.0.
+        return 1.0 <= c.duration <= 17.0

     num_in_total = len(train_cuts)
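For context, this predicate is the body of a duration filter that is typically applied with lhotse's CutSet.filter; the enclosing function name and the call below are assumptions, since the diff does not show them:

def remove_short_and_long_utt(c):
    # Keep only utterances with duration between 1 second and max seconds
    # If you want to use a big max-duration, you can set it as 17.0.
    return 1.0 <= c.duration <= 17.0

train_cuts = train_cuts.filter(remove_short_and_long_utt)  # assumed usage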