fix typos

yaozengwei 2023-06-14 11:26:51 +08:00
parent 11ea660c86
commit ac6c894391
6 changed files with 29 additions and 42 deletions

View File

@@ -97,7 +97,7 @@ def read_sound_files(
             sample_rate == expected_sample_rate
         ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
-        ans.append(wave[0])
+        ans.append(wave[0].contiguous())
     return ans
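The only change in this file is the added .contiguous(). wave[0] takes the first channel as a view; if the loaded waveform is itself a non-contiguous view (for example, a transposed (samples, channels) tensor), the slice inherits a stride, and downstream feature extractors generally expect a dense buffer. A minimal sketch of the behavior, using synthetic data rather than the script's actual loader:

import torch

# (samples, channels) storage transposed to (channels, samples):
# the transpose is a non-contiguous view, and so is its channel slice.
wave = torch.randn(16000, 2).t()

print(wave[0].is_contiguous())               # False: strided view into shared storage
print(wave[0].contiguous().is_contiguous())  # True: .contiguous() copies into a dense buffer

On an already-contiguous slice, .contiguous() is a no-op that returns the same tensor, so the call is cheap insurance.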

View File

@@ -159,11 +159,11 @@ def get_parser():
         (2) nbest-rescoring. Extract n paths from the decoding lattice,
             rescore them with an LM, the path with
             the highest score is the decoding result.
-            We call it HLG decoding + n-gram LM rescoring.
+            We call it HLG decoding + nbest n-gram LM rescoring.
         (3) whole-lattice-rescoring - Use an LM to rescore the
             decoding lattice and then use 1best to decode the
             rescored lattice.
-            We call it HLG decoding + n-gram LM rescoring.
+            We call it HLG decoding + whole-lattice n-gram LM rescoring.
         """,
     )
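The renamed descriptions distinguish the two rescoring modes more precisely: nbest-rescoring scores only n sampled paths with the n-gram LM, while whole-lattice-rescoring composes the entire lattice with the LM before taking the one-best path. A rough sketch of how the two branches typically dispatch in these scripts; the helper names follow icefall.decode, but treat the exact signatures as assumptions:

# Hypothetical dispatch; argument names are illustrative.
if params.method == "nbest-rescoring":
    # Sample num_paths paths from the lattice, rescore each with the
    # n-gram LM G, and keep the highest-scoring path.
    best_path = rescore_with_n_best_list(
        lattice=lattice,
        G=G,
        num_paths=params.num_paths,
        lm_scale_list=[params.ngram_lm_scale],
    )
elif params.method == "whole-lattice-rescoring":
    # Compose the full lattice with G so every path is rescored,
    # then take the one-best path from the rescored lattice.
    best_path = rescore_with_whole_lattice(
        lattice=lattice,
        G_with_epsilon_loops=G,
        lm_scale_list=[params.ngram_lm_scale],
    )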
@@ -210,15 +210,6 @@ def get_parser():
         """,
     )
-
-    parser.add_argument(
-        "--num-classes",
-        type=int,
-        default=500,
-        help="""
-        Vocab size in the BPE model.
-        """,
-    )
 
     parser.add_argument(
         "--sample-rate",
         type=int,
@@ -258,7 +249,7 @@ def read_sound_files(
             sample_rate == expected_sample_rate
         ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
-        ans.append(wave[0])
+        ans.append(wave[0].contiguous())
     return ans
@@ -272,6 +263,11 @@ def main():
     params.update(get_decoding_params())
     params.update(vars(args))
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+    params.vocab_size = sp.get_piece_size()
+
     logging.info(f"{params}")
 
     device = torch.device("cpu")
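This is the counterpart of the removed --num-classes flag: the vocabulary size is now read from the BPE model itself, so it cannot drift out of sync with a flag default. A standalone sketch of the same calls, with an illustrative model path:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # illustrative path
vocab_size = sp.get_piece_size()        # e.g. 500 for a 500-piece BPE model
max_token_id = vocab_size - 1           # largest token id, used below for the CTC topology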
@@ -321,9 +317,7 @@ def main():
     if params.method == "ctc-decoding":
         logging.info("Use CTC decoding")
-        bpe_model = spm.SentencePieceProcessor()
-        bpe_model.load(params.bpe_model)
-        max_token_id = params.num_classes - 1
+        max_token_id = params.vocab_size - 1
         H = k2.ctc_topo(
             max_token=max_token_id,
@@ -346,7 +340,7 @@ def main():
             lattice=lattice, use_double_scores=params.use_double_scores
         )
         token_ids = get_texts(best_path)
-        hyps = bpe_model.decode(token_ids)
+        hyps = sp.decode(token_ids)
         hyps = [s.split() for s in hyps]
     elif params.method in [
         "1best",

View File

@ -75,7 +75,7 @@ class AsrModel(nn.Module):
assert ( assert (
use_transducer or use_ctc use_transducer or use_ctc
), f"At least one of them should be True, but gotten use_transducer={use_transducer}, use_ctc={use_ctc}" ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
assert isinstance(encoder, EncoderInterface), type(encoder) assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -98,6 +98,9 @@ class AsrModel(nn.Module):
             self.simple_lm_proj = ScaledLinear(
                 decoder_dim, vocab_size, initial_scale=0.25
             )
+        else:
+            assert decoder is None
+            assert joiner is None
 
         self.use_ctc = use_ctc
         if use_ctc:
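The new else branch makes the constructor contract symmetric: transducer components may only be present when use_transducer is True. A sketch of the two legal configurations; the argument list is abbreviated and assumed from this diff:

# Hypothetical construction; other required arguments omitted.
transducer_ctc_model = AsrModel(
    encoder=encoder, decoder=decoder, joiner=joiner,
    use_transducer=True, use_ctc=True,
)

ctc_only_model = AsrModel(
    encoder=encoder, decoder=None, joiner=None,  # anything else now trips the new asserts
    use_transducer=False, use_ctc=True,
)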
@@ -135,7 +138,7 @@ class AsrModel(nn.Module):
         encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
-        assert torch.all(encoder_out_lens > 0)
+        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
 
         return encoder_out, encoder_out_lens
@@ -342,10 +345,7 @@ class AsrModel(nn.Module):
         if self.use_ctc:
             # Compute CTC loss
-            targets = [t for tokens in y.tolist() for t in tokens]
-            # of shape (sum(y_lens),)
-            targets = torch.tensor(targets, device=x.device, dtype=torch.int64)
+            targets = y.values
             ctc_loss = self.forward_ctc(
                 encoder_out=encoder_out,
                 encoder_out_lens=encoder_out_lens,
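The targets rewrite is more than a style fix: y is a k2.RaggedTensor of token ids, and its .values attribute is already the flattened tensor of shape (sum(y_lens),) that the removed lines rebuilt by hand through Python lists. A quick demonstration of the equivalence:

import k2
import torch

y = k2.RaggedTensor([[2, 5, 9], [7, 3]])  # two utterances' token ids
targets = y.values                        # tensor([2, 5, 9, 7, 3])

# Equivalent to the removed round-trip through Python lists:
flat = [t for tokens in y.tolist() for t in tokens]
assert torch.equal(targets, torch.tensor(flat, dtype=targets.dtype))

It also avoids a host-device round-trip, since .values stays on whatever device y lives on.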

View File

@@ -245,7 +245,7 @@ def read_sound_files(
             sample_rate == expected_sample_rate
         ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
-        ans.append(wave[0])
+        ans.append(wave[0].contiguous())
     return ans

View File

@@ -167,11 +167,11 @@ def get_parser():
         (2) nbest-rescoring. Extract n paths from the decoding lattice,
             rescore them with an LM, the path with
             the highest score is the decoding result.
-            We call it HLG decoding + n-gram LM rescoring.
+            We call it HLG decoding + nbest n-gram LM rescoring.
         (3) whole-lattice-rescoring - Use an LM to rescore the
             decoding lattice and then use 1best to decode the
             rescored lattice.
-            We call it HLG decoding + n-gram LM rescoring.
+            We call it HLG decoding + whole-lattice n-gram LM rescoring.
         """,
     )
@@ -218,15 +218,6 @@ def get_parser():
         """,
     )
-
-    parser.add_argument(
-        "--num-classes",
-        type=int,
-        default=500,
-        help="""
-        Vocab size in the BPE model.
-        """,
-    )
 
     parser.add_argument(
         "--sample-rate",
         type=int,
@@ -268,7 +259,7 @@ def read_sound_files(
             f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
-        ans.append(wave[0])
+        ans.append(wave[0].contiguous())
     return ans
@ -281,7 +272,11 @@ def main():
# add decoding params # add decoding params
params.update(get_decoding_params()) params.update(get_decoding_params())
params.update(vars(args)) params.update(vars(args))
params.vocab_size = params.num_classes
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
params.vocab_size = sp.get_piece_size()
params.blank_id = 0 params.blank_id = 0
logging.info(f"{params}") logging.info(f"{params}")
@@ -340,9 +335,7 @@ def main():
     if params.method == "ctc-decoding":
         logging.info("Use CTC decoding")
-        bpe_model = spm.SentencePieceProcessor()
-        bpe_model.load(params.bpe_model)
-        max_token_id = params.num_classes - 1
+        max_token_id = params.vocab_size - 1
         H = k2.ctc_topo(
             max_token=max_token_id,
@@ -365,7 +358,7 @@ def main():
             lattice=lattice, use_double_scores=params.use_double_scores
         )
         token_ids = get_texts(best_path)
-        hyps = bpe_model.decode(token_ids)
+        hyps = sp.decode(token_ids)
         hyps = [s.split() for s in hyps]
     elif params.method in [
         "1best",

View File

@@ -607,7 +607,7 @@ def get_model(params: AttributeDict) -> nn.Module:
     assert (
         params.use_transducer or params.use_ctc
     ), (f"At least one of them should be True, "
-        f"but gotten params.use_transducer={params.use_transducer}, "
+        f"but got params.use_transducer={params.use_transducer}, "
         f"params.use_ctc={params.use_ctc}")
 
     encoder_embed = get_encoder_embed(params)
encoder_embed = get_encoder_embed(params) encoder_embed = get_encoder_embed(params)