update codes

luomingshuang 2022-03-11 13:44:46 +08:00
parent 4a725a5eec
commit 396aaefbaa
7 changed files with 88 additions and 84 deletions

View File

@@ -90,6 +90,7 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
print(pieces, token2id)
pieces = [token2id[i] for i in pieces]
for i in range(len(pieces) - 1):
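The added line maps the BPE pieces of a word from strings to integer token ids before the FST arcs are built. A minimal illustration, assuming token2id is the token symbol table (piece string -> integer id); the entries below are made up:

token2id = {"<blk>": 0, "▁HE": 1, "LL": 2, "O": 3}  # toy symbol table
pieces = ["▁HE", "LL", "O"]  # BPE pieces of one word
pieces = [token2id[i] for i in pieces]  # -> [1, 2, 3]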

View File

@@ -16,57 +16,35 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.output_linear = nn.Linear(input_dim, output_dim)
self.inner_linear = nn.Linear(input_dim, inner_dim)
self.output_linear = nn.Linear(inner_dim, output_dim)
def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
encoder_out_len: torch.Tensor,
decoder_out_len: torch.Tensor,
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, self.input_dim).
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, U, self.input_dim).
Output from the decoder. Its shape is (N, T, s_range, C).
Returns:
Return a tensor of shape (sum_all_TU, self.output_dim).
Return a tensor of shape (N, T, s_range, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == self.input_dim
assert decoder_out.size(2) == self.input_dim
assert encoder_out.ndim == decoder_out.ndim == 4
assert encoder_out.shape == decoder_out.shape
N = encoder_out.size(0)
logit = encoder_out + decoder_out
encoder_out_list = [
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
]
logit = self.inner_linear(torch.tanh(logit))
decoder_out_list = [
decoder_out[i, : decoder_out_len[i], :] for i in range(N)
]
output = self.output_linear(F.relu(logit))
x = [
e.unsqueeze(1) + d.unsqueeze(0)
for e, d in zip(encoder_out_list, decoder_out_list)
]
x = [p.reshape(-1, self.input_dim) for p in x]
x = torch.cat(x)
activations = torch.tanh(x)
logits = self.output_linear(activations)
return logits
return output
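Read together, the added lines replace the old joiner, which looped over utterances and concatenated per-(t, u) logits into a (sum_all_TU, output_dim) tensor, with a pruned joiner that operates on 4-D tensors of shape (N, T, s_range, C) and adds an inner projection. A consolidated sketch of the new module, reconstructed from the added lines above (not a verbatim copy of the file):

import torch
import torch.nn as nn
import torch.nn.functional as F


class Joiner(nn.Module):
    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
        super().__init__()
        self.inner_linear = nn.Linear(input_dim, inner_dim)
        self.output_linear = nn.Linear(inner_dim, output_dim)

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        # Both inputs are already pruned to (N, T, s_range, C),
        # e.g. by k2.do_rnnt_pruning in the model's forward().
        assert encoder_out.ndim == decoder_out.ndim == 4
        assert encoder_out.shape == decoder_out.shape

        logit = encoder_out + decoder_out
        logit = self.inner_linear(torch.tanh(logit))
        output = self.output_linear(F.relu(logit))
        return output

The extra inner_dim decouples the joiner's hidden size from the encoder/decoder output dimension, which the old single-linear joiner could not do.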

View File

@@ -1,4 +1,4 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import k2
import torch
@@ -64,7 +63,9 @@ class Transducer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
modified_transducer_prob: float = 0.0,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
) -> torch.Tensor:
"""
Args:
@@ -76,10 +77,23 @@ class Transducer(nn.Module):
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
modified_transducer_prob:
The probability to use modified transducer loss.
prune_range:
The prune range for the rnnt loss; it specifies how many symbols (context)
we consider for each frame when computing the loss.
am_scale:
The scale used to smooth the loss with the am (output of the encoder
network) part.
lm_scale:
The scale used to smooth the loss with the lm (output of the predictor
network) part.
Returns:
Return a tuple (simple_loss, pruned_loss).
Note:
Regarding am_scale and lm_scale, the loss function takes the form:
lm_scale * lm_probs + am_scale * am_probs +
(1 - lm_scale - am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
@@ -97,47 +111,59 @@ class Transducer(nn.Module):
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
# decoder_out: [B, S + 1, C]
decoder_out = self.decoder(sos_y_padded)
# +1 here since a blank is prepended to each utterance.
logits = self.joiner(
encoder_out=encoder_out,
decoder_out=decoder_out,
encoder_out_len=x_lens,
decoder_out_len=y_lens + 1,
)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
# We don't put this `import` at the beginning of the file
# as it is required only in the training, not during the
# inference stage
import optimized_transducer
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device
)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
assert 0 <= modified_transducer_prob <= 1
if modified_transducer_prob == 0:
one_sym_per_frame = False
elif random.random() < modified_transducer_prob:
# random.random() returns a float in the range [0, 1)
one_sym_per_frame = True
else:
one_sym_per_frame = False
loss = optimized_transducer.transducer_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out,
am=encoder_out,
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
one_sym_per_frame=one_sym_per_frame,
from_log_softmax=False,
return_grad=True,
)
return loss
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, C]
# lm_pruned : [B, T, prune_range, C]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=encoder_out, lm=decoder_out, ranges=ranges
)
# logits : [B, T, prune_range, C]
logits = self.joiner(am_pruned, lm_pruned)
pruned_loss = k2.rnnt_loss_pruned(
logits=logits,
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)
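The rewritten forward pass computes the loss in two stages: k2.rnnt_loss_smoothed gives a cheap "simple" loss over the full (T, S) grid together with the gradients used to decide the pruning bounds, and those bounds restrict the expensive joiner and final loss to a band of prune_range symbols per frame. A self-contained sketch of that pipeline on random tensors; the sizes, scales, and blank id below are illustrative, and it assumes a k2 build that ships the pruned RNN-T loss API used above:

import k2
import torch

# Toy sizes: batch, frames, symbols per utterance, vocab size (blank id = 0).
B, T, S, C = 2, 20, 10, 100
prune_range = 5

am = torch.randn(B, T, C, requires_grad=True)      # stands in for encoder_out
lm = torch.randn(B, S + 1, C, requires_grad=True)  # stands in for decoder_out
symbols = torch.randint(1, C, (B, S), dtype=torch.int64)
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S  # number of symbols per utterance
boundary[:, 3] = T  # number of frames per utterance

# Stage 1: smoothed "simple" loss, plus the gradients needed for pruning.
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    lm_only_scale=0.25,
    am_only_scale=0.0,
    boundary=boundary,
    reduction="sum",
    return_grad=True,
)

# Stage 2: pick prune_range symbols per frame, prune am/lm to that band,
# join them, and evaluate the pruned loss on (B, T, prune_range, C) only.
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=prune_range
)
am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
logits = am_pruned + lm_pruned  # the model applies self.joiner(am_pruned, lm_pruned) here
pruned_loss = k2.rnnt_loss_pruned(
    logits=logits,
    symbols=symbols,
    ranges=ranges,
    termination_symbol=0,
    boundary=boundary,
    reduction="sum",
)

The two losses are returned separately, presumably so the training script can weight the simple and pruned terms independently when forming the final objective.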

View File

@@ -71,6 +71,8 @@ def compute_fbank_tedlium():
recordings=m["recordings"],
supervisions=m["supervisions"],
)
# Split long cuts into many short, non-overlapping cuts
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
if "train" in partition:
cut_set = (
cut_set
@@ -85,8 +87,6 @@ def compute_fbank_tedlium():
executor=ex,
storage_type=ChunkedLilcomHdf5Writer,
)
# Split long cuts into many short, non-overlapping cuts
cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")

View File

@@ -29,9 +29,9 @@ dl_dir=$PWD/download
# It will generate data/lang_bpe_xxx,
# data/lang_bpe_yyy if the array contains xxx, yyy
vocab_sizes=(
5000
2000
1000
#5000
#2000
#1000
500
)

View File

@@ -71,7 +71,7 @@ class TedLiumAsrDataModule:
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/fbank"),
default=Path("data/fbank_overlap_false"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
@@ -348,7 +348,6 @@ class TedLiumAsrDataModule:
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
print(self.args.manifest_dir)
return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")
@lru_cache()

View File

@@ -77,7 +77,7 @@ def get_parser():
parser.add_argument(
"--master-port",
type=int,
default=12354,
default=12350,
help="Master port to use for DDP training.",
)
@@ -108,7 +108,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless/exp",
default="pruned_transducer_stateless/exp-4-gpus-300",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@@ -658,7 +658,7 @@ def run(rank, world_size, args):
# Keep only utterances with duration between 1 second and max seconds.
# Here, max is set to 17.0 (it was 20.0 before this change); the smaller
# cap leaves more headroom when training with a big --max-duration.
return 1.0 <= c.duration <= 20.0
return 1.0 <= c.duration <= 17.0
num_in_total = len(train_cuts)