-
Notifications
You must be signed in to change notification settings - Fork 0
/
encode.py
42 lines (36 loc) · 1.51 KB
/
encode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# (C) 2018-present Klebert Engineering
import argparse
import os
import sys
sys.path.append(os.path.dirname(os.path.realpath(__file__))+"/modules")
from deepspell.corpus import DSCorpus
from deepspell.models.encoder import DSVariationalLstmAutoEncoder
def main():
    """Encode every token of a TSV corpus with a pre-trained variational
    LSTM auto-encoder and write the resulting embeddings to an output
    directory.

    Command-line arguments:
        --corpus        Path to the input corpus TSV file.
        --encoder       Path to the model JSON descriptor used for encoding.
        --output-dir/-o Directory where generated embeddings are stored.
        --batch-size/-b Number of samples processed in parallel.
    """
    # NOTE(review): a positional string here sets ArgumentParser's `prog`,
    # not `description`, and "Quality Evaluator" looks copy-pasted from a
    # sibling evaluation script — confirm the intended name before changing.
    arg_parser = argparse.ArgumentParser("NDS AutoCompletion Quality Evaluator")
    arg_parser.add_argument(
        "--corpus",
        # default="corpora/deepspell_data_north_america_nozip_v2.tsv",
        default="corpora/deepspell_minimal.tsv",
        help="Path to the corpus from which benchmark samples should be drawn.")
    arg_parser.add_argument(
        "--encoder",
        default="models/deepsp_spell-v1_na-lower_lr003_dec50_bat3072_emb8_fw128-128_bw128_de128-128_drop80.json",
        help="Path to the model JSON descriptor that should be used for token encoding.")
    arg_parser.add_argument(
        "--output-dir", "-o",
        dest="output_path",
        default="corpora/",
        help="Directory path to where the generated embeddings should be stored.")
    arg_parser.add_argument(
        "--batch-size", "-b",
        dest="batch_size",
        default=16384,
        type=int,
        help="Number of samples that should be processed in parallel.")
    args = arg_parser.parse_args()

    print("Encoding FTS Corpus... ")
    print(f" ... encoder: {args.encoder}")
    print(f" ... corpus: {args.corpus}")
    print("=======================================================================")
    print("")

    # "logs" is passed as the model's second ctor argument — presumably a
    # log/summary directory; verify against DSVariationalLstmAutoEncoder.
    encoder_model = DSVariationalLstmAutoEncoder(args.encoder, "logs")
    encoder_model.encode_corpus(args.corpus, args.output_path, batch_size=args.batch_size)


# Guard so that importing this module does not immediately parse sys.argv
# and launch a (potentially long-running) corpus encoding job.
if __name__ == "__main__":
    main()