diff --git a/Tutorial.ipynb b/Tutorial.ipynb index 26e3f6c5..d42380e7 100644 --- a/Tutorial.ipynb +++ b/Tutorial.ipynb @@ -95,14 +95,28 @@ { "cell_type": "code", "execution_count": 4, - "metadata": { - "collapsed": true - }, - "outputs": [], + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['For each taxi_id, predict the first fare, after trip_id 0.',\n", + " 'For each taxi_id, predict the first fare, after trip_id 0.',\n", + " 'For each taxi_id, predict the first fare, after trip_id 0.']" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "prediction_problems_json = trane.prediction_problems_to_json_file(\n", " probs, table_meta, entity_id_column, label_generating_column, time_column, \"prediction_problems.json\")\n", "\n", + "trane.generate_nl_description(\n", + " probs, table_meta, entity_id_column, label_generating_column, time_column, trane.ConstantIntegerCutoffTimes(0))\n", + "\n", "# with open(\"prediction_problems.json\", \"w\") as f:\n", "# json.dump(json.loads(prediction_problems_json), f, indent=4, separators=(',', ': '))" ] @@ -301,7 +315,9 @@ { "cell_type": "code", "execution_count": 6, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "entity_to_data_and_cutoff_dict = trane.ConstantIntegerCutoffTimes(0).generate_cutoffs(entity_to_data_dict)" diff --git a/trane/utils/generate_nl_description.py b/trane/utils/generate_nl_description.py index f1ca9a49..040dd4d0 100644 --- a/trane/utils/generate_nl_description.py +++ b/trane/utils/generate_nl_description.py @@ -36,7 +36,7 @@ def generate_nl_description( """ def description(prob): - return "For each {col}, predict{dataop_des},{filter_des},{cutoff_des}.".format( + return "For each {col}, predict{dataop_des}{filter_des}{cutoff_des}.".format( col=entity_id_column, dataop_des=dataop_description(prob), filter_des=filter_description(prob), @@ -97,9 +97,9 @@ def aggop_description(): def cutoff_description(prob): if isinstance(cutoff_strategy, ConstantIntegerCutoffTimes): - return " after {col} {cutoff}".format(col=time_column, cutoff=cutoff_strategy.integer_cutoff) + return ", after {col} {cutoff}".format(col=time_column, cutoff=cutoff_strategy.integer_cutoff) elif isinstance(cutoff_strategy, ConstantDatetimeCutoffTimes): - return " after {col} {cutoff}".format(col=time_column, cutoff=cutoff_strategy.datetime_cutoff) + return ", after {col} {cutoff}".format(col=time_column, cutoff=cutoff_strategy.datetime_cutoff) else: raise " (unknown cutoff time)" @@ -113,7 +113,7 @@ def filter_description(prob): filter_op = prob.operations[0] if isinstance(filter_op, AllFilterOp): return "" - return " with {col} {op} {threshold}".format( + return ", with {col} {op} {threshold}".format( col=filter_op.column_name, op=filter_op_str_dict[type(filter_op)], threshold=filter_op.param_values['threshold'])