fix typos

yaozengwei 2023-06-14 11:26:51 +08:00
parent 11ea660c86
commit ac6c894391
6 changed files with 29 additions and 42 deletions

View File

@@ -97,7 +97,7 @@ def read_sound_files(
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
-ans.append(wave[0])
+ans.append(wave[0].contiguous())
return ans
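For context on the `.contiguous()` guard (the same change recurs in the other files below): a row slice of the `(channels, samples)` tensor returned by `torchaudio.load` is usually already contiguous, but views produced by transposes or similar operations are strided, and native-code consumers of the waveform may require a dense layout. A minimal, self-contained sketch; the shapes are illustrative, not taken from this commit:

import torch

# Stand-in for the (channels, samples) tensor that torchaudio.load returns.
wave = torch.randn(2, 16000)

# A row slice of a contiguous 2-D tensor is itself contiguous, so the
# added .contiguous() is a free no-op in the common case...
print(wave[0].is_contiguous())  # True

# ...but a transposed view is strided, and .contiguous() makes a dense copy.
wave_t = wave.t()  # (samples, channels) view over the same storage
print(wave_t.is_contiguous())  # False
print(wave_t.contiguous().is_contiguous())  # True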

View File

@@ -159,11 +159,11 @@ def get_parser():
(2) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an LM, the path with
the highest score is the decoding result.
-We call it HLG decoding + n-gram LM rescoring.
+We call it HLG decoding + nbest n-gram LM rescoring.
(3) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
-We call it HLG decoding + n-gram LM rescoring.
+We call it HLG decoding + whole-lattice n-gram LM rescoring.
""",
)
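The two rescoring modes renamed above differ in what competes for the final answer: nbest-rescoring rescores only the n paths sampled from the lattice, while whole-lattice-rescoring composes the entire lattice with the n-gram LM before 1best decoding. A hedged sketch of the dispatch, assuming the `rescore_with_n_best_list` and `rescore_with_whole_lattice` helpers from `icefall.decode`; the argument names are from memory and should be treated as assumptions:

from icefall.decode import rescore_with_n_best_list, rescore_with_whole_lattice


def rescore(lattice, G, method, num_paths, lm_scale_list, nbest_scale):
    """Return a dict mapping an lm_scale key to its best-path FsaVec."""
    if method == "nbest-rescoring":
        # Only the sampled paths compete; a good path that was never
        # sampled cannot win, but the cost stays bounded by num_paths.
        return rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=num_paths,
            lm_scale_list=lm_scale_list,
            nbest_scale=nbest_scale,
        )
    elif method == "whole-lattice-rescoring":
        # Every path in the lattice is rescored before 1best decoding.
        return rescore_with_whole_lattice(
            lattice=lattice,
            G_with_epsilon_loops=G,
            lm_scale_list=lm_scale_list,
        )
    else:
        raise ValueError(f"Unsupported rescoring method: {method}")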
@@ -210,15 +210,6 @@ def get_parser():
""",
)
-parser.add_argument(
-"--num-classes",
-type=int,
-default=500,
-help="""
-Vocab size in the BPE model.
-""",
-)
parser.add_argument(
"--sample-rate",
type=int,
@@ -258,7 +249,7 @@ def read_sound_files(
sample_rate == expected_sample_rate
), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
-ans.append(wave[0])
+ans.append(wave[0].contiguous())
return ans
@@ -272,6 +263,11 @@ def main():
params.update(get_decoding_params())
params.update(vars(args))
+sp = spm.SentencePieceProcessor()
+sp.load(params.bpe_model)
+params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
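The block added here (and mirrored in the other main() below) replaces the removed `--num-classes` flag: the vocabulary size is now read from the BPE model itself, so it cannot drift out of sync with the checkpoint, and the same processor object is reused for decoding later. A minimal sketch using the standard sentencepiece API; the model path is illustrative:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # illustrative path

# Derive the sizes from the model file instead of a hand-set flag.
vocab_size = sp.get_piece_size()  # e.g. 500 for a 500-token BPE model
max_token_id = vocab_size - 1     # the value k2.ctc_topo needs below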
@@ -321,9 +317,7 @@ def main():
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
-bpe_model = spm.SentencePieceProcessor()
-bpe_model.load(params.bpe_model)
-max_token_id = params.num_classes - 1
+max_token_id = params.vocab_size - 1
H = k2.ctc_topo(
max_token=max_token_id,
@@ -346,7 +340,7 @@ def main():
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
-hyps = bpe_model.decode(token_ids)
+hyps = sp.decode(token_ids)
hyps = [s.split() for s in hyps]
elif params.method in [
"1best",

View File

@@ -75,7 +75,7 @@ class AsrModel(nn.Module):
assert (
use_transducer or use_ctc
-), f"At least one of them should be True, but gotten use_transducer={use_transducer}, use_ctc={use_ctc}"
+), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -98,6 +98,9 @@ class AsrModel(nn.Module):
self.simple_lm_proj = ScaledLinear(
decoder_dim, vocab_size, initial_scale=0.25
)
+else:
+assert decoder is None
+assert joiner is None
self.use_ctc = use_ctc
if use_ctc:
@@ -135,7 +138,7 @@ class AsrModel(nn.Module):
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-assert torch.all(encoder_out_lens > 0)
+assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
return encoder_out, encoder_out_lens
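The strengthened assert attaches `(x_lens, encoder_out_lens)` as the assertion message. Python evaluates that second operand only when the condition fails, so the passing path costs nothing and a failure reports the offending lengths instead of a bare AssertionError. A small illustration with made-up lengths:

import torch

x_lens = torch.tensor([16, 9, 2])
encoder_out_lens = torch.tensor([4, 2, 0])  # a zero-length output slipped in

try:
    # The tuple after the comma becomes the AssertionError payload and is
    # only evaluated if the condition is False.
    assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
except AssertionError as e:
    print(e)  # (tensor([16,  9,  2]), tensor([4, 2, 0]))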
@@ -342,10 +345,7 @@ class AsrModel(nn.Module):
if self.use_ctc:
# Compute CTC loss
-targets = [t for tokens in y.tolist() for t in tokens]
-# of shape (sum(y_lens),)
-targets = torch.tensor(targets, device=x.device, dtype=torch.int64)
+targets = y.values
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
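`y` here is a `k2.RaggedTensor` holding one token sequence per utterance, and its elements are already stored as a single flat tensor, so `y.values` replaces the removed Python-level flatten plus `torch.tensor` copy. A small sketch of the equivalence:

import torch
import k2

y = k2.RaggedTensor([[5, 9, 11], [7, 2]])  # one token sequence per utterance

# Removed approach: round-trip through Python lists, then rebuild a tensor.
slow = torch.tensor([t for tokens in y.tolist() for t in tokens])

# New approach: the ragged tensor already stores its elements as one flat
# tensor of shape (sum(y_lens),); .values simply exposes that storage.
fast = y.values

assert torch.equal(slow, fast.to(slow.dtype))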

View File

@@ -245,7 +245,7 @@ def read_sound_files(
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
-ans.append(wave[0])
+ans.append(wave[0].contiguous())
return ans

View File

@@ -167,11 +167,11 @@ def get_parser():
(2) nbest-rescoring. Extract n paths from the decoding lattice,
rescore them with an LM, the path with
the highest score is the decoding result.
-We call it HLG decoding + n-gram LM rescoring.
+We call it HLG decoding + nbest n-gram LM rescoring.
(3) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
-We call it HLG decoding + n-gram LM rescoring.
+We call it HLG decoding + whole-lattice n-gram LM rescoring.
""",
)
@@ -218,15 +218,6 @@ def get_parser():
""",
)
-parser.add_argument(
-"--num-classes",
-type=int,
-default=500,
-help="""
-Vocab size in the BPE model.
-""",
-)
parser.add_argument(
"--sample-rate",
type=int,
@@ -268,7 +259,7 @@ def read_sound_files(
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
-ans.append(wave[0])
+ans.append(wave[0].contiguous())
return ans
@@ -281,7 +272,11 @@ def main():
# add decoding params
params.update(get_decoding_params())
params.update(vars(args))
-params.vocab_size = params.num_classes
+sp = spm.SentencePieceProcessor()
+sp.load(params.bpe_model)
+params.vocab_size = sp.get_piece_size()
params.blank_id = 0
logging.info(f"{params}")
@@ -340,9 +335,7 @@ def main():
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
-bpe_model = spm.SentencePieceProcessor()
-bpe_model.load(params.bpe_model)
-max_token_id = params.num_classes - 1
+max_token_id = params.vocab_size - 1
H = k2.ctc_topo(
max_token=max_token_id,
@@ -365,7 +358,7 @@ def main():
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
-hyps = bpe_model.decode(token_ids)
+hyps = sp.decode(token_ids)
hyps = [s.split() for s in hyps]
elif params.method in [
"1best",

View File

@@ -607,7 +607,7 @@ def get_model(params: AttributeDict) -> nn.Module:
assert (
params.use_transducer or params.use_ctc
), (f"At least one of them should be True, "
-f"but gotten params.use_transducer={params.use_transducer}, "
+f"but got params.use_transducer={params.use_transducer}, "
f"params.use_ctc={params.use_ctc}")
encoder_embed = get_encoder_embed(params)
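Message fix aside, this assertion encodes the same configuration contract as the model-level change above: a model may drop the transducer head or the CTC head, but not both, and disabling the transducer implies no decoder or joiner were built. A compact sketch of that invariant; the function and argument names are illustrative, not from the repo:

def check_heads(use_transducer: bool, use_ctc: bool, decoder, joiner) -> None:
    assert use_transducer or use_ctc, (
        f"At least one of them should be True, "
        f"but got use_transducer={use_transducer}, use_ctc={use_ctc}"
    )
    if not use_transducer:
        # The decoder and joiner exist only to serve the transducer head.
        assert decoder is None and joiner is None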