update codes

luomingshuang 2022-03-11 13:44:46 +08:00
parent 4a725a5eec
commit 396aaefbaa
7 changed files with 88 additions and 84 deletions

View File

@@ -90,6 +90,7 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state

         word = word2id[word]
+        print(pieces, token2id)
         pieces = [token2id[i] for i in pieces]

         for i in range(len(pieces) - 1):

View File

@@ -16,57 +16,35 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F


 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
         super().__init__()
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.output_linear = nn.Linear(input_dim, output_dim)
+        self.inner_linear = nn.Linear(input_dim, inner_dim)
+        self.output_linear = nn.Linear(inner_dim, output_dim)

     def forward(
-        self,
-        encoder_out: torch.Tensor,
-        decoder_out: torch.Tensor,
-        encoder_out_len: torch.Tensor,
-        decoder_out_len: torch.Tensor,
+        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, self.input_dim).
+            Output from the encoder. Its shape is (N, T, s_range, C).
           decoder_out:
-            Output from the decoder. Its shape is (N, U, self.input_dim).
+            Output from the decoder. Its shape is (N, T, s_range, C).
         Returns:
-          Return a tensor of shape (sum_all_TU, self.output_dim).
+          Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 3
-        assert encoder_out.size(0) == decoder_out.size(0)
-        assert encoder_out.size(2) == self.input_dim
-        assert decoder_out.size(2) == self.input_dim
-
-        N = encoder_out.size(0)
-
-        encoder_out_list = [
-            encoder_out[i, : encoder_out_len[i], :] for i in range(N)
-        ]
-        decoder_out_list = [
-            decoder_out[i, : decoder_out_len[i], :] for i in range(N)
-        ]
-
-        x = [
-            e.unsqueeze(1) + d.unsqueeze(0)
-            for e, d in zip(encoder_out_list, decoder_out_list)
-        ]
-        x = [p.reshape(-1, self.input_dim) for p in x]
-        x = torch.cat(x)
-
-        activations = torch.tanh(x)
-
-        logits = self.output_linear(activations)
-
-        return logits
+        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.shape == decoder_out.shape
+
+        logit = encoder_out + decoder_out
+        logit = self.inner_linear(torch.tanh(logit))
+
+        output = self.output_linear(F.relu(logit))
+
+        return output

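The reworked Joiner no longer builds the full (N, T, U, C) joint lattice; it assumes its two inputs have already been pruned to a window of s_range symbols per frame, so the addition is elementwise over tensors of identical shape. Below is a minimal standalone shape check, a sketch under the assumption that the class is used as in the diff above; every size (batch 2, 10 frames, window 5, model dim 512, vocabulary 500) is illustrative, not taken from the commit.

import torch
import torch.nn as nn
import torch.nn.functional as F


# The Joiner from the diff above, restated so this sketch runs on its own.
class Joiner(nn.Module):
    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
        super().__init__()
        self.inner_linear = nn.Linear(input_dim, inner_dim)
        self.output_linear = nn.Linear(inner_dim, output_dim)

    def forward(self, encoder_out, decoder_out):
        # Elementwise combine the pruned outputs, then project to the vocab.
        logit = self.inner_linear(torch.tanh(encoder_out + decoder_out))
        return self.output_linear(F.relu(logit))


# Illustrative sizes (assumptions, not from the commit).
N, T, s_range, C = 2, 10, 5, 512
joiner = Joiner(input_dim=C, inner_dim=C, output_dim=500)

am_pruned = torch.randn(N, T, s_range, C)  # pruned encoder output
lm_pruned = torch.randn(N, T, s_range, C)  # pruned decoder output

logits = joiner(am_pruned, lm_pruned)
assert logits.shape == (N, T, s_range, 500)
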
View File

@@ -1,4 +1,4 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import random

 import k2
 import torch
@@ -64,7 +63,9 @@ class Transducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
-        modified_transducer_prob: float = 0.0,
+        prune_range: int = 5,
+        am_scale: float = 0.0,
+        lm_scale: float = 0.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -76,10 +77,23 @@ class Transducer(nn.Module):
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
-          modified_transducer_prob:
-            The probability to use modified transducer loss.
+          prune_range:
+            The prune range for the rnnt loss; it specifies how many symbols
+            (context) we consider for each frame when computing the loss.
+          am_scale:
+            The scale to smooth the loss with the am (output of the encoder
+            network) part.
+          lm_scale:
+            The scale to smooth the loss with the lm (output of the predictor
+            network) part.
         Returns:
           Return the transducer loss.
+
+        Note:
+          Regarding am_scale and lm_scale, they make the loss function take
+          the form:
+            lm_scale * lm_probs + am_scale * am_probs
+              + (1 - lm_scale - am_scale) * combined_probs
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
@@ -97,47 +111,59 @@ class Transducer(nn.Module):
         blank_id = self.decoder.blank_id
         sos_y = add_sos(y, sos_id=blank_id)

+        # sos_y_padded: [B, S + 1], start with SOS.
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = sos_y_padded.to(torch.int64)

+        # decoder_out: [B, S + 1, C]
         decoder_out = self.decoder(sos_y_padded)

-        # +1 here since a blank is prepended to each utterance.
-        logits = self.joiner(
-            encoder_out=encoder_out,
-            decoder_out=decoder_out,
-            encoder_out_len=x_lens,
-            decoder_out_len=y_lens + 1,
-        )
-
-        # rnnt_loss requires 0 padded targets
         # Note: y does not start with SOS
+        # y_padded : [B, S]
         y_padded = y.pad(mode="constant", padding_value=0)

-        # We don't put this `import` at the beginning of the file
-        # as it is required only in the training, not during the
-        # reference stage
-        import optimized_transducer
-
-        assert 0 <= modified_transducer_prob <= 1
-
-        if modified_transducer_prob == 0:
-            one_sym_per_frame = False
-        elif random.random() < modified_transducer_prob:
-            # random.random() returns a float in the range [0, 1)
-            one_sym_per_frame = True
-        else:
-            one_sym_per_frame = False
-
-        loss = optimized_transducer.transducer_loss(
-            logits=logits,
-            targets=y_padded,
-            logit_lengths=x_lens,
-            target_lengths=y_lens,
-            blank=blank_id,
+        y_padded = y_padded.to(torch.int64)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens
+
+        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+            lm=decoder_out,
+            am=encoder_out,
+            symbols=y_padded,
+            termination_symbol=blank_id,
+            lm_only_scale=lm_scale,
+            am_only_scale=am_scale,
+            boundary=boundary,
             reduction="sum",
-            one_sym_per_frame=one_sym_per_frame,
-            from_log_softmax=False,
+            return_grad=True,
         )

-        return loss
+        # ranges : [B, T, prune_range]
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad=px_grad,
+            py_grad=py_grad,
+            boundary=boundary,
+            s_range=prune_range,
+        )
+
+        # am_pruned : [B, T, prune_range, C]
+        # lm_pruned : [B, T, prune_range, C]
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
+            am=encoder_out, lm=decoder_out, ranges=ranges
+        )
+
+        # logits : [B, T, prune_range, C]
+        logits = self.joiner(am_pruned, lm_pruned)
+
+        pruned_loss = k2.rnnt_loss_pruned(
+            logits=logits,
+            symbols=y_padded,
+            ranges=ranges,
+            termination_symbol=blank_id,
+            boundary=boundary,
+            reduction="sum",
+        )
+
+        return (simple_loss, pruned_loss)

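The new forward pass is the four-step pruned RNN-T recipe from k2: rnnt_loss_smoothed computes a cheap "simple" loss over the full encoder/decoder outputs and returns the px/py gradients, get_rnnt_prune_ranges turns those gradients into a per-frame window of prune_range symbols, do_rnnt_pruning gathers the outputs into that window, and rnnt_loss_pruned evaluates the exact loss only inside it. The caller now receives a (simple_loss, pruned_loss) pair and must combine the two terms itself; the weighting is left to the training script. The runnable sketch below only illustrates the boundary tensor both losses consume: as I understand k2's convention, each row is [begin_symbol, begin_frame, end_symbol, end_frame], so with zero offsets columns 2 and 3 carry the label and frame counts. The lengths are made up.

import torch

x_lens = torch.tensor([100, 80])  # valid frames per utterance
y_lens = torch.tensor([20, 15])   # valid labels per utterance

boundary = torch.zeros((2, 4), dtype=torch.int64)
boundary[:, 2] = y_lens  # end_symbol = number of labels
boundary[:, 3] = x_lens  # end_frame = number of frames
print(boundary)
# tensor([[  0,   0,  20, 100],
#         [  0,   0,  15,  80]])
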
View File

@@ -71,6 +71,8 @@ def compute_fbank_tedlium():
            recordings=m["recordings"],
            supervisions=m["supervisions"],
        )
+        # Split long cuts into many short and un-overlapping cuts
+        cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
        if "train" in partition:
            cut_set = (
                cut_set
@@ -85,8 +87,6 @@ def compute_fbank_tedlium():
                executor=ex,
                storage_type=ChunkedLilcomHdf5Writer,
            )
-            # Split long cuts into many short and un-overlapping cuts
-            cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
        cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")

View File

@@ -29,9 +29,9 @@ dl_dir=$PWD/download
 # It will generate data/lang_bpe_xxx,
 # data/lang_bpe_yyy if the array contains xxx, yyy
 vocab_sizes=(
-  5000
-  2000
-  1000
+  #5000
+  #2000
+  #1000
   500
 )

View File

@@ -71,7 +71,7 @@ class TedLiumAsrDataModule:
        group.add_argument(
            "--manifest-dir",
            type=Path,
-            default=Path("data/fbank"),
+            default=Path("data/fbank_overlap_false"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
@@ -348,7 +348,6 @@ class TedLiumAsrDataModule:
    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
-        print(self.args.manifest_dir)
        return load_manifest(self.args.manifest_dir / "cuts_train.json.gz")

    @lru_cache()

View File

@@ -77,7 +77,7 @@ def get_parser():
    parser.add_argument(
        "--master-port",
        type=int,
-        default=12354,
+        default=12350,
        help="Master port to use for DDP training.",
    )
@@ -108,7 +108,7 @@ def get_parser():
    parser.add_argument(
        "--exp-dir",
        type=str,
-        default="pruned_transducer_stateless/exp",
+        default="pruned_transducer_stateless/exp-4-gpus-300",
        help="""The experiment dir.
        It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@@ -658,7 +658,7 @@ def run(rank, world_size, args):
        # Keep only utterances with duration between 1 second and max seconds
        # Here, we set max as 20.0.
        # If you want to use a big max-duration, you can set it as 17.0.
-        return 1.0 <= c.duration <= 20.0
+        return 1.0 <= c.duration <= 17.0

    num_in_total = len(train_cuts)
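
For reference, this is the duration predicate that run() filters train_cuts with before training; the change simply drops anything longer than 17 seconds. A tiny self-contained check of the new behavior, with made-up durations and a stub in place of a real lhotse Cut (the function name matches the closure in train.py, but the rest is illustrative):

def remove_short_and_long_utt(c) -> bool:
    # Keep only utterances between 1 and 17 seconds (the new upper bound).
    return 1.0 <= c.duration <= 17.0


class FakeCut:
    # Stub standing in for a lhotse Cut; only `duration` matters here.
    def __init__(self, duration: float):
        self.duration = duration


cuts = [FakeCut(0.5), FakeCut(5.0), FakeCut(18.0)]
kept = [c.duration for c in cuts if remove_short_and_long_utt(c)]
assert kept == [5.0]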