fix typos

This commit is contained in:
parent 11ea660c86
commit ac6c894391
@@ -97,7 +97,7 @@ def read_sound_files(
             sample_rate == expected_sample_rate
         ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
-        ans.append(wave[0])
+        ans.append(wave[0].contiguous())
     return ans

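Why `.contiguous()` here: slicing out one channel can produce a strided view rather than packed memory, and downstream consumers of the waveform (e.g. the feature extractor) may require packed tensors. A minimal sketch of the distinction, with made-up shapes:

    import torch

    # Shape (num_channels, num_samples); the transpose makes it a strided view.
    wave = torch.randn(16000, 2).t()
    channel0 = wave[0]                 # element stride is 2, not packed
    print(channel0.is_contiguous())    # False
    packed = channel0.contiguous()     # copies only if needed
    print(packed.is_contiguous())      # True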
@@ -159,11 +159,11 @@ def get_parser():
         (2) nbest-rescoring. Extract n paths from the decoding lattice,
             rescore them with an LM, the path with
             the highest score is the decoding result.
-            We call it HLG decoding + n-gram LM rescoring.
+            We call it HLG decoding + nbest n-gram LM rescoring.
         (3) whole-lattice-rescoring - Use an LM to rescore the
             decoding lattice and then use 1best to decode the
             rescored lattice.
-            We call it HLG decoding + n-gram LM rescoring.
+            We call it HLG decoding + whole-lattice n-gram LM rescoring.
         """,
     )

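The renamed help text distinguishes the two rescoring modes more precisely: n-best rescoring scores only n extracted paths, while whole-lattice rescoring rescores every path in the lattice. A toy sketch of the n-best selection step (scores and `lm_scale` are made up; the real code operates on k2 lattices):

    lm_scale = 0.5  # hypothetical LM interpolation weight
    paths = [       # hypothetical n-best list with per-path scores
        {"text": "HELLO WORLD", "am_score": -12.3, "lm_score": -4.1},
        {"text": "HELLO WORD", "am_score": -11.9, "lm_score": -7.8},
    ]
    best = max(paths, key=lambda p: p["am_score"] + lm_scale * p["lm_score"])
    print(best["text"])  # the path with the highest combined score wins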
@@ -210,15 +210,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--num-classes",
-        type=int,
-        default=500,
-        help="""
-        Vocab size in the BPE model.
-        """,
-    )
-
     parser.add_argument(
         "--sample-rate",
         type=int,
@@ -258,7 +249,7 @@ def read_sound_files(
             sample_rate == expected_sample_rate
         ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
-        ans.append(wave[0])
+        ans.append(wave[0].contiguous())
     return ans

@@ -272,6 +263,11 @@ def main():
     params.update(get_decoding_params())
     params.update(vars(args))

+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    params.vocab_size = sp.get_piece_size()
+
     logging.info(f"{params}")

     device = torch.device("cpu")
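Loading the BPE model once in `main()` and deriving `vocab_size` from it removes the need for a separate `--num-classes` flag and any chance of the two disagreeing. A standalone sketch (the model path is illustrative):

    import sentencepiece as spm

    sp = spm.SentencePieceProcessor()
    sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path
    print(sp.get_piece_size())              # e.g. 500; one entry per BPE token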
@@ -321,9 +317,7 @@ def main():

     if params.method == "ctc-decoding":
         logging.info("Use CTC decoding")
-        bpe_model = spm.SentencePieceProcessor()
-        bpe_model.load(params.bpe_model)
-        max_token_id = params.num_classes - 1
+        max_token_id = params.vocab_size - 1

         H = k2.ctc_topo(
             max_token=max_token_id,
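`max_token_id` is passed to `k2.ctc_topo`, which builds the CTC topology FSA over token IDs `0..max_token` (0 being the blank). A minimal sketch of that call, assuming a 500-token vocabulary:

    import torch
    import k2

    vocab_size = 500  # assumed; in the patched code it comes from the BPE model
    H = k2.ctc_topo(
        max_token=vocab_size - 1,  # largest token ID in the vocabulary
        modified=False,
        device=torch.device("cpu"),
    )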
@@ -346,7 +340,7 @@ def main():
             lattice=lattice, use_double_scores=params.use_double_scores
         )
         token_ids = get_texts(best_path)
-        hyps = bpe_model.decode(token_ids)
+        hyps = sp.decode(token_ids)
         hyps = [s.split() for s in hyps]
     elif params.method in [
         "1best",
@@ -75,7 +75,7 @@ class AsrModel(nn.Module):

         assert (
             use_transducer or use_ctc
-        ), f"At least one of them should be True, but gotten use_transducer={use_transducer}, use_ctc={use_ctc}"
+        ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"

         assert isinstance(encoder, EncoderInterface), type(encoder)
@@ -98,6 +98,9 @@ class AsrModel(nn.Module):
             self.simple_lm_proj = ScaledLinear(
                 decoder_dim, vocab_size, initial_scale=0.25
             )
+        else:
+            assert decoder is None
+            assert joiner is None

         self.use_ctc = use_ctc
         if use_ctc:
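The new `else` branch makes the constructor fail fast if transducer-only components are passed to a pure-CTC model. A toy version of the guard (hypothetical helper, names as in the diff):

    def check_heads(use_transducer: bool, decoder, joiner) -> None:
        # A pure-CTC model must not silently receive transducer components.
        if not use_transducer:
            assert decoder is None
            assert joiner is None

    check_heads(use_transducer=False, decoder=None, joiner=None)  # passes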
@@ -135,7 +138,7 @@ class AsrModel(nn.Module):
         encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)

         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
-        assert torch.all(encoder_out_lens > 0)
+        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

         return encoder_out, encoder_out_lens
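Attaching `(x_lens, encoder_out_lens)` as the assertion message means a failure reports the offending length tensors instead of a bare `AssertionError`; the message operand is evaluated only when the condition is false. For example:

    import torch

    lens = torch.tensor([17, 0, 23])
    try:
        assert torch.all(lens > 0), (lens,)
    except AssertionError as e:
        print(e)  # (tensor([17,  0, 23]),)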
@@ -342,10 +345,7 @@ class AsrModel(nn.Module):

         if self.use_ctc:
             # Compute CTC loss
-            targets = [t for tokens in y.tolist() for t in tokens]
-            # of shape (sum(y_lens),)
-            targets = torch.tensor(targets, device=x.device, dtype=torch.int64)
-
+            targets = y.values
             ctc_loss = self.forward_ctc(
                 encoder_out=encoder_out,
                 encoder_out_lens=encoder_out_lens,
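`targets = y.values` works because `y` is a `k2.RaggedTensor`; its `values` property is already the flattened 1-D tensor of token IDs on the tensor's own device, so the Python-list round trip and the `torch.tensor(...)` rebuild become unnecessary. A small sketch:

    import torch
    import k2

    y = k2.RaggedTensor([[2, 5, 9], [7, 1]])  # one sub-list per utterance
    print(y.values)  # tensor([2, 5, 9, 7, 1]); dtype may differ from int64
    # The old code rebuilt the same flat tensor via Python lists:
    old = torch.tensor([t for tokens in y.tolist() for t in tokens])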
@@ -245,7 +245,7 @@ def read_sound_files(
             sample_rate == expected_sample_rate
         ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
         # We use only the first channel
-        ans.append(wave[0])
+        ans.append(wave[0].contiguous())
     return ans

@@ -167,11 +167,11 @@ def get_parser():
         (2) nbest-rescoring. Extract n paths from the decoding lattice,
             rescore them with an LM, the path with
             the highest score is the decoding result.
-            We call it HLG decoding + n-gram LM rescoring.
+            We call it HLG decoding + nbest n-gram LM rescoring.
         (3) whole-lattice-rescoring - Use an LM to rescore the
             decoding lattice and then use 1best to decode the
             rescored lattice.
-            We call it HLG decoding + n-gram LM rescoring.
+            We call it HLG decoding + whole-lattice n-gram LM rescoring.
         """,
     )

@@ -218,15 +218,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--num-classes",
-        type=int,
-        default=500,
-        help="""
-        Vocab size in the BPE model.
-        """,
-    )
-
     parser.add_argument(
         "--sample-rate",
         type=int,
@@ -268,7 +259,7 @@ def read_sound_files(
             f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
         )
         # We use only the first channel
-        ans.append(wave[0])
+        ans.append(wave[0].contiguous())
     return ans

@@ -281,7 +272,11 @@ def main():
     # add decoding params
     params.update(get_decoding_params())
     params.update(vars(args))
-    params.vocab_size = params.num_classes
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    params.vocab_size = sp.get_piece_size()
     params.blank_id = 0

     logging.info(f"{params}")
@@ -340,9 +335,7 @@ def main():

     if params.method == "ctc-decoding":
         logging.info("Use CTC decoding")
-        bpe_model = spm.SentencePieceProcessor()
-        bpe_model.load(params.bpe_model)
-        max_token_id = params.num_classes - 1
+        max_token_id = params.vocab_size - 1

         H = k2.ctc_topo(
             max_token=max_token_id,
@@ -365,7 +358,7 @@ def main():
             lattice=lattice, use_double_scores=params.use_double_scores
         )
         token_ids = get_texts(best_path)
-        hyps = bpe_model.decode(token_ids)
+        hyps = sp.decode(token_ids)
         hyps = [s.split() for s in hyps]
     elif params.method in [
         "1best",
@@ -607,7 +607,7 @@ def get_model(params: AttributeDict) -> nn.Module:
     assert (
         params.use_transducer or params.use_ctc
     ), (f"At least one of them should be True, "
-        f"but gotten params.use_transducer={params.use_transducer}, "
+        f"but got params.use_transducer={params.use_transducer}, "
         f"params.use_ctc={params.use_ctc}")

     encoder_embed = get_encoder_embed(params)