mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
fix typos
This commit is contained in:
parent
11ea660c86
commit
ac6c894391
@ -97,7 +97,7 @@ def read_sound_files(
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
|
@ -159,11 +159,11 @@ def get_parser():
|
||||
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an LM, the path with
|
||||
the highest score is the decoding result.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
We call it HLG decoding + nbest n-gram LM rescoring.
|
||||
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
We call it HLG decoding + whole-lattice n-gram LM rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -210,15 +210,6 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-classes",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""
|
||||
Vocab size in the BPE model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
@ -258,7 +249,7 @@ def read_sound_files(
|
||||
sample_rate == expected_sample_rate
|
||||
), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
@ -272,6 +263,11 @@ def main():
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
@ -321,9 +317,7 @@ def main():
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Use CTC decoding")
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(params.bpe_model)
|
||||
max_token_id = params.num_classes - 1
|
||||
max_token_id = params.vocab_size - 1
|
||||
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
@ -346,7 +340,7 @@ def main():
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
token_ids = get_texts(best_path)
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
hyps = sp.decode(token_ids)
|
||||
hyps = [s.split() for s in hyps]
|
||||
elif params.method in [
|
||||
"1best",
|
||||
|
@ -75,7 +75,7 @@ class AsrModel(nn.Module):
|
||||
|
||||
assert (
|
||||
use_transducer or use_ctc
|
||||
), f"At least one of them should be True, but gotten use_transducer={use_transducer}, use_ctc={use_ctc}"
|
||||
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
|
||||
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
|
||||
@ -98,6 +98,9 @@ class AsrModel(nn.Module):
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim, vocab_size, initial_scale=0.25
|
||||
)
|
||||
else:
|
||||
assert decoder is None
|
||||
assert joiner is None
|
||||
|
||||
self.use_ctc = use_ctc
|
||||
if use_ctc:
|
||||
@ -135,7 +138,7 @@ class AsrModel(nn.Module):
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
assert torch.all(encoder_out_lens > 0)
|
||||
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
@ -342,10 +345,7 @@ class AsrModel(nn.Module):
|
||||
|
||||
if self.use_ctc:
|
||||
# Compute CTC loss
|
||||
targets = [t for tokens in y.tolist() for t in tokens]
|
||||
# of shape (sum(y_lens),)
|
||||
targets = torch.tensor(targets, device=x.device, dtype=torch.int64)
|
||||
|
||||
targets = y.values
|
||||
ctc_loss = self.forward_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
|
@ -245,7 +245,7 @@ def read_sound_files(
|
||||
sample_rate == expected_sample_rate
|
||||
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
|
@ -167,11 +167,11 @@ def get_parser():
|
||||
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an LM, the path with
|
||||
the highest score is the decoding result.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
We call it HLG decoding + nbest n-gram LM rescoring.
|
||||
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
We call it HLG decoding + whole-lattice n-gram LM rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
@ -218,15 +218,6 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-classes",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""
|
||||
Vocab size in the BPE model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
@ -268,7 +259,7 @@ def read_sound_files(
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
ans.append(wave[0].contiguous())
|
||||
return ans
|
||||
|
||||
|
||||
@ -281,7 +272,11 @@ def main():
|
||||
# add decoding params
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
params.vocab_size = params.num_classes
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
params.blank_id = 0
|
||||
|
||||
logging.info(f"{params}")
|
||||
@ -340,9 +335,7 @@ def main():
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Use CTC decoding")
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(params.bpe_model)
|
||||
max_token_id = params.num_classes - 1
|
||||
max_token_id = params.vocab_size - 1
|
||||
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
@ -365,7 +358,7 @@ def main():
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
token_ids = get_texts(best_path)
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
hyps = sp.decode(token_ids)
|
||||
hyps = [s.split() for s in hyps]
|
||||
elif params.method in [
|
||||
"1best",
|
||||
|
@ -607,7 +607,7 @@ def get_model(params: AttributeDict) -> nn.Module:
|
||||
assert (
|
||||
params.use_transducer or params.use_ctc
|
||||
), (f"At least one of them should be True, "
|
||||
f"but gotten params.use_transducer={params.use_transducer}, "
|
||||
f"but got params.use_transducer={params.use_transducer}, "
|
||||
f"params.use_ctc={params.use_ctc}")
|
||||
|
||||
encoder_embed = get_encoder_embed(params)
|
||||
|
Loading…
x
Reference in New Issue
Block a user