commit 39b6879d72 (parent 3bad661f6f)

    Version that is running...
@@ -572,10 +572,7 @@ class BidirectionalConformer(nn.Module):
         tokens_padded = pad_sequence(token_ids_tensors, batch_first=True,
                                      padding_value=padding_id).to(positive_embed_shifted.device)
 
-        print("tokens_padded = ", tokens_padded)
         tokens_key_padding_mask = decoder_padding_mask(tokens_padded, ignore_id=padding_id)
-        print("tokens_key_padding_mask=", tokens_key_padding_mask)
-
 
         # Let S be the length of the longest sentence (padded)
         token_embedding = self.token_embed(tokens_padded) * self.token_embed_scale  # (N, S) -> (N, S, C)
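The two removed print() calls were debug output; the padding step they inspected can be exercised in isolation. A minimal sketch in plain PyTorch; decoder_padding_mask is icefall's own helper, assumed here to be equivalent to an equality test against the padding id:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    token_ids_tensors = [torch.tensor([5, 9, 2]), torch.tensor([7, 3])]
    padding_id = 0
    # Pad to the longest sequence in the batch: (N, S) = (2, 3).
    tokens_padded = pad_sequence(
        token_ids_tensors, batch_first=True, padding_value=padding_id
    )
    # True at padded positions, the convention nn.Transformer expects for
    # key_padding_mask (assumed to match decoder_padding_mask).
    tokens_key_padding_mask = tokens_padded == padding_id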
@@ -15,6 +15,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+#--master-port 12344 --world-size 3 --max-duration=200 --bucketing-sampler=True --start-epoch=5
+
 
 import argparse
 import collections
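The new comment records the options the script was run with. Assuming the usual icefall entry point (the file name is not shown in this hunk), the corresponding invocation would look like:

    python train.py --master-port 12344 --world-size 3 --max-duration=200 \
        --bucketing-sampler=True --start-epoch=5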
@@ -22,7 +24,7 @@ import logging
 from pathlib import Path
 import random  # temp..
 from shutil import copyfile
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 import k2
 import torch
@@ -174,6 +176,7 @@ def get_params() -> AttributeDict:
         "attention_dim": 512,
         "nhead": 8,
         "num_trunk_encoder_layers": 12,
         "num_ctc_encoder_layers": 2,
         "num_decoder_layers": 6,
         "num_reverse_encoder_layers": 4,
         "num_reverse_decoder_layers": 4,
@@ -285,7 +288,7 @@ class LossRecord(collections.defaultdict):
         # makes undefined items default to int() which is zero.
         super(LossRecord, self).__init__(int)
 
-    def __add__(self, other: LossRecord) -> LossRecord:
+    def __add__(self, other: 'LossRecord') -> 'LossRecord':
         ans = LossRecord()
         for k, v in self.items():
             ans[k] = v
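The quoting is the actual fix here: while the class body is being evaluated, the name LossRecord is not yet bound, so the unquoted annotation raises NameError at import time. A string annotation defers evaluation (as would a module-level "from __future__ import annotations" on Python 3.7+). A condensed, runnable sketch assembled from this hunk and the next:

    import collections

    class LossRecord(collections.defaultdict):
        def __init__(self):
            # Undefined items default to int(), i.e. zero.
            super(LossRecord, self).__init__(int)

        def __add__(self, other: 'LossRecord') -> 'LossRecord':
            ans = LossRecord()
            for k, v in self.items():
                ans[k] = v
            for k, v in other.items():
                ans[k] = ans[k] + v
            return ans

        def __mul__(self, alpha: float) -> 'LossRecord':
            ans = LossRecord()
            for k, v in self.items():
                ans[k] = v * alpha
            return ans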
@@ -293,7 +296,7 @@ class LossRecord(collections.defaultdict):
             ans[k] = ans[k] + v
         return ans
 
-    def __mul__(self, alpha: float) -> LossRecord:
+    def __mul__(self, alpha: float) -> 'LossRecord':
         ans = LossRecord()
         for k, v in self.items():
             ans[k] = v * alpha
@@ -303,13 +306,13 @@ class LossRecord(collections.defaultdict):
     def __str__(self) -> str:
         ans = ''
         for k, v in self.norm_items():
-            norm_value = '%.2g' % v
+            norm_value = '%.4g' % v
             ans += (str(k) + '=' + str(norm_value) + ', ')
         frames = str(self['frames'])
         ans += 'over ' + frames + ' frames.'
         return ans
 
-    def norm_items(self) -> List[Tuple[string, float]]
+    def norm_items(self) -> List[Tuple[str, float]]:
         """
         Returns a list of pairs, like:
           [('ctc_loss', 0.1), ('att_loss', 0.07)]
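Two distinct fixes in this hunk: the norm_items signature was a syntax error (string is not a Python type, and the trailing colon was missing), and the display precision was widened. The precision change is easy to see in isolation:

    >>> '%.2g' % 0.123456   # two significant digits
    '0.12'
    >>> '%.4g' % 0.123456   # four significant digits
    '0.1235'

Four significant digits matter when a loss moves only slightly from epoch to epoch.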
@@ -320,7 +323,7 @@ class LossRecord(collections.defaultdict):
             if k != 'frames':
                 norm_value = float(v) / num_frames
                 ans.append((k, norm_value))
 
         return ans
 
     def reduce(self, device):
         """
@@ -353,7 +356,7 @@ def compute_loss(
     batch: dict,
     graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
-) -> Tuple[Tensor, LossRecord]
+) -> Tuple[Tensor, LossRecord]:
     """
     Compute loss function (including CTC, attention, and reverse-attention terms).
 
@@ -562,7 +565,7 @@ def train_one_epoch(
     """
     model.train()
 
-    tot_loss = LossInfo()
+    tot_loss = LossRecord()
 
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
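LossInfo was a stale name (the class is called LossRecord above), so the old line raised NameError on the first training step. The running total then accumulates per-batch records via LossRecord.__add__; schematically, with per_batch_records standing in for the LossRecord values that compute_loss returns per batch:

    # Sketch of how the running total accumulates across the epoch.
    tot_loss = LossRecord()
    for loss_info in per_batch_records:   # hypothetical iterable of records
        tot_loss = tot_loss + loss_info   # uses LossRecord.__add__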
@@ -679,7 +682,7 @@ def run(rank, world_size, args):
         num_self_predictor_layers=params.num_self_predictor_layers,
         subsampling_factor=params.subsampling_factor,
         is_bpe=params.is_bpe,
-        discretization_tot_classes=params.discretization_tot_clases,
+        discretization_tot_classes=params.discretization_tot_classes,
         discretization_num_groups=params.discretization_num_groups,
     )
 