Skip to content

Commit

Permalink
fix minor bug in natural language description system. also add natura…
Browse files Browse the repository at this point in the history
…l language description to the tutorial.
  • Loading branch information
leix28 committed Feb 5, 2018
1 parent dd88524 commit 0df4e65
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
26 changes: 21 additions & 5 deletions Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,28 @@
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['For each taxi_id, predict the first fare, after trip_id 0.',\n",
" 'For each taxi_id, predict the first fare, after trip_id 0.',\n",
" 'For each taxi_id, predict the first fare, after trip_id 0.']"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prediction_problems_json = trane.prediction_problems_to_json_file(\n",
" probs, table_meta, entity_id_column, label_generating_column, time_column, \"prediction_problems.json\")\n",
"\n",
"trane.generate_nl_description(\n",
" probs, table_meta, entity_id_column, label_generating_column, time_column, trane.ConstantIntegerCutoffTimes(0))\n",
"\n",
"# with open(\"prediction_problems.json\", \"w\") as f:\n",
"# json.dump(json.loads(prediction_problems_json), f, indent=4, separators=(',', ': '))"
]
Expand Down Expand Up @@ -301,7 +315,9 @@
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"entity_to_data_and_cutoff_dict = trane.ConstantIntegerCutoffTimes(0).generate_cutoffs(entity_to_data_dict)"
Expand Down
8 changes: 4 additions & 4 deletions trane/utils/generate_nl_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def generate_nl_description(
"""

def description(prob):
return "For each {col}, predict{dataop_des},{filter_des},{cutoff_des}.".format(
return "For each {col}, predict{dataop_des}{filter_des}{cutoff_des}.".format(
col=entity_id_column,
dataop_des=dataop_description(prob),
filter_des=filter_description(prob),
Expand Down Expand Up @@ -97,9 +97,9 @@ def aggop_description():

def cutoff_description(prob):
if isinstance(cutoff_strategy, ConstantIntegerCutoffTimes):
return " after {col} {cutoff}".format(col=time_column, cutoff=cutoff_strategy.integer_cutoff)
return ", after {col} {cutoff}".format(col=time_column, cutoff=cutoff_strategy.integer_cutoff)
elif isinstance(cutoff_strategy, ConstantDatetimeCutoffTimes):
return " after {col} {cutoff}".format(col=time_column, cutoff=cutoff_strategy.datetime_cutoff)
return ", after {col} {cutoff}".format(col=time_column, cutoff=cutoff_strategy.datetime_cutoff)
else:
raise " (unknown cutoff time)"

Expand All @@ -113,7 +113,7 @@ def filter_description(prob):
filter_op = prob.operations[0]
if isinstance(filter_op, AllFilterOp):
return ""
return " with {col} {op} {threshold}".format(
return ", with {col} {op} {threshold}".format(
col=filter_op.column_name,
op=filter_op_str_dict[type(filter_op)],
threshold=filter_op.param_values['threshold'])
Expand Down

0 comments on commit 0df4e65

Please sign in to comment.