From c0215838dac7f1920815d5831c178ffe9a9c3654 Mon Sep 17 00:00:00 2001 From: a-kore <37000693+a-kore@users.noreply.github.com> Date: Mon, 11 Mar 2024 13:18:07 -0400 Subject: [PATCH] Add ability to truncate report and centralise report generation. (#579) * add last_n_evals arg to backend export * finish slider w/ text * merge main * change to update one json+html * fix notebook path typo * fix notebook path typo for los --------- Co-authored-by: Amrit Krishnan --- cyclops/report/model_card/fields.py | 5 - cyclops/report/model_card/sections.py | 5 + cyclops/report/report.py | 23 +- .../templates/model_report/macros.jinja | 36 +- .../templates/model_report/model_report.jinja | 45 + cyclops/report/templates/model_report/plot.js | 780 +++++++++--------- cyclops/report/utils.py | 26 - .../kaggle/heart_failure_prediction.ipynb | 23 +- .../mimiciv/mortality_prediction.ipynb | 2 +- .../tutorials/synthea/los_prediction.ipynb | 2 +- tests/cyclops/report/test_utils.py | 9 - 11 files changed, 501 insertions(+), 455 deletions(-) diff --git a/cyclops/report/model_card/fields.py b/cyclops/report/model_card/fields.py index d4cdf9c5c..12dd29497 100644 --- a/cyclops/report/model_card/fields.py +++ b/cyclops/report/model_card/fields.py @@ -594,11 +594,6 @@ class MetricCard( description="History of the metric over time.", ) - trend: Optional[StrictStr] = Field( - None, - description="The trend of the metric over time.", - ) - timestamps: Optional[List[StrictStr]] = Field( None, description="Timestamps for each point in the history.", diff --git a/cyclops/report/model_card/sections.py b/cyclops/report/model_card/sections.py index 814d27fbf..a57bf773f 100644 --- a/cyclops/report/model_card/sections.py +++ b/cyclops/report/model_card/sections.py @@ -30,6 +30,11 @@ class Overview(BaseModelCardSection): """Overview section with aggregate metrics.""" + last_n_evals: Optional[int] = Field( + None, + description="The number of evaluations to display in the model card.", + ) + metric_cards: Optional[MetricCardCollection] = Field( None, description="Comparative metrics between baseline and periodic report.", diff --git a/cyclops/report/report.py b/cyclops/report/report.py index 444f8a985..15bf9cfee 100644 --- a/cyclops/report/report.py +++ b/cyclops/report/report.py @@ -51,7 +51,6 @@ get_slices, get_thresholds, get_timestamps, - get_trends, regex_replace, regex_search, str_to_snake_case, @@ -98,7 +97,7 @@ def from_json_file( The path to a JSON file containing model card data. output_dir : str, optional The directory to save the report to. If not provided, the report will - be saved in a directory called `cyclops_reports` in the current working + be saved in a directory called `cyclops_report` in the current working directory. Returns @@ -1054,6 +1053,7 @@ def export( template_path: Optional[str] = None, interactive: bool = True, save_json: bool = True, + last_n_evals: Optional[int] = None, synthetic_timestamp: Optional[str] = None, ) -> str: """Export the model card report to an HTML file. @@ -1070,6 +1070,9 @@ def export( Whether to create an interactive HTML report. The default is True. save_json : bool, optional Whether to save the model card as a JSON file. The default is True. + last_n_evals : int, optional + The number of most recent evaluations to include in the report and + calculate trends for. If not provided, all evaluations will be included. synthetic_timestamp : str, optional A synthetic timestamp to use for the report. This is useful for generating back-dated reports. The default is None, which uses the @@ -1089,10 +1092,8 @@ def export( # write to file if synthetic_timestamp is not None: - today = synthetic_timestamp today_now = synthetic_timestamp else: - today = dt_date.today().strftime("%Y-%m-%d") today_now = dt_datetime.now().strftime("%Y-%m-%d %H:%M:%S") current_report_metrics: List[List[PerformanceMetric]] = [] @@ -1102,9 +1103,7 @@ def export( report_paths = glob.glob( os.path.join( self.output_dir, - "cyclops_reports", - "*", - "*", + "cyclops_report", "*.json", ), ) @@ -1135,6 +1134,10 @@ def export( metric_cards, ) + if self._model_card.overview is not None: + last_n_evals = 0 if last_n_evals is None else last_n_evals + self._model_card.overview.last_n_evals = last_n_evals + self._validate() template = self._get_jinja_template(template_path=template_path) @@ -1143,7 +1146,6 @@ def export( "sweep_graphics": sweep_graphics, "get_slices": get_slices, "get_thresholds": get_thresholds, - "get_trends": get_trends, "get_passed": get_passed, "get_names": get_names, "get_histories": get_histories, @@ -1154,12 +1156,9 @@ def export( plotlyjs = get_plotlyjs() if interactive else None content = template.render(model_card=self._model_card, plotlyjs=plotlyjs) - now = dt_datetime.now().strftime("%H-%M-%S") report_path = os.path.join( self.output_dir, - "cyclops_reports", - today, - now, + "cyclops_report", output_filename or "model_card.html", ) self._write_file(report_path, content) diff --git a/cyclops/report/templates/model_report/macros.jinja b/cyclops/report/templates/model_report/macros.jinja index ce462b350..3b22d1ae8 100644 --- a/cyclops/report/templates/model_report/macros.jinja +++ b/cyclops/report/templates/model_report/macros.jinja @@ -172,8 +172,14 @@ {% macro render_perf(name, comp)%}
-

How is your model doing?


-

A quick glance of your most important metrics.

+
+

How is your model doing?


+

A quick glance of your most important metrics.

+
+ +

Last {{ comp.last_n_evals }} Evaluations

+ +
{% for metric_card in comp.metric_cards.collection%} {% if metric_card.slice == 'overall' %} {{ render_metric_card(metric_card, loop.index-1, "subcard_overview") }} @@ -185,15 +191,23 @@ {% macro render_perf_over_time(name, comp)%}
-

How is your model doing over time?


-

See how your model is performing over several metrics and subgroups over time.

-
-

Multi-plot Selection:

-
- - - - + +

Last {{ comp.last_n_evals }} Evaluations

+ +
+
+
+

How is your model doing over time?


+

See how your model is performing over several metrics and subgroups over time.

+
+
+

Multi-plot Selection:

+
+ + + + +
diff --git a/cyclops/report/templates/model_report/model_report.jinja b/cyclops/report/templates/model_report/model_report.jinja index f4d5bb157..15d22218b 100644 --- a/cyclops/report/templates/model_report/model_report.jinja +++ b/cyclops/report/templates/model_report/model_report.jinja @@ -531,8 +531,53 @@ "rgb(23, 190, 207)" ]; + // create global variable for max_n_evals + var histories = JSON.parse({{ get_histories(model_card)|safe|tojson }}); + // get max_n_evals from histories + var history_data = []; + for (let i = 0; i < histories[0].length; i++) { + history_data.push(parseFloat(histories[0][i])); + } + var max_n_evals = history_data.length; + +// Add event listeners to radio buttons + for (let input of inputs_all) { + input.addEventListener('change', updatePlot); + } + // Add event listener to update plot when window is resized + window.addEventListener('resize', updatePlot); + for (let selection of plot_selection) { + selection.addEventListener('change', updatePlotSelection); + } + // Initial update when the page loads updatePlot(); document.addEventListener('DOMContentLoaded', setCollapseButton); + + function updateLastNEvals() { + var n_evals_slider_p = document.getElementById("n_evals_slider_p"); + var slider_p_num = document.getElementById("slider_p_num"); + var n_evals_slider_pot = document.getElementById("n_evals_slider_pot"); + var slider_pot_num = document.getElementById("slider_pot_num"); + + n_evals_slider_p.max = max_n_evals; + n_evals_slider_pot.max = max_n_evals; + + if (n_evals_slider_p !== null) { + n_evals_slider_p.oninput = function() { + last_n_evals = this.value; + slider_p_num.innerHTML = last_n_evals; + generate_model_card_plot(); + } + } + if (n_evals_slider_pot !== null) { + n_evals_slider_pot.oninput = function() { + last_n_evals = this.value; + slider_pot_num.innerHTML = last_n_evals; + updatePlot(); + } + } + } + document.addEventListener('DOMContentLoaded', updateLastNEvals); diff --git a/cyclops/report/templates/model_report/plot.js b/cyclops/report/templates/model_report/plot.js index 7a0bc6dcb..7934c00de 100644 --- a/cyclops/report/templates/model_report/plot.js +++ b/cyclops/report/templates/model_report/plot.js @@ -15,228 +15,232 @@ function updatePlot() { var inputs_name = []; var inputs_value = []; for (let i = 0; i < inputs.length; i++) { - inputs_name.push(inputs[i].name); - inputs_value.push(inputs[i].value); + inputs_name.push(inputs[i].name); + inputs_value.push(inputs[i].value); } var plot_number = parseInt(plot_selected.value.split(" ")[1]-1); var selection = []; for (let i = 0; i < inputs_value.length; i++) { - selection.push(inputs_name[i] + ":" + inputs_value[i]); + selection.push(inputs_name[i] + ":" + inputs_value[i]); } selection.sort(); selections[plot_number] = selection; // if plot_selected is "+" then add new radio button to plot_selection called "Plot N" where last plot is N-1 but keep "+" at end and set new radio button to checked for second last element if (plot_selected.value === "+") { - // if 10 plots already exist, don't add new plot and gray out "+" - if (plot_selection.length === 13) { + // if 10 plots already exist, don't add new plot and gray out "+" + if (plot_selection.length === 13) { plot_selected.checked = false; label_selection[-1].style.color = "gray"; return; - } - var new_plot = document.createElement("input"); - new_plot.type = "radio"; - new_plot.id = "Plot " + (plot_selection.length); - new_plot.name = "plot"; - new_plot.value = "Plot " + (plot_selection.length); - new_plot.checked = true; - var new_label = document.createElement("label"); - new_label.htmlFor = "Plot " + (plot_selection.length); - new_label.innerHTML = "Plot " + (plot_selection.length); - - // Parse plot_color to get r, g, b values - var plot_color = plot_colors[plot_selection.length] - const [r, g, b] = plot_color.match(/\d+/g); - const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; - // set background color of new radio button to plot_color - new_label.style.backgroundColor = rgbaColor; - new_label.style.border = "2px solid " + plot_color; - new_label.style.color = plot_color; - - // insert new radio button and label before "+" radio button and after last radio button - plot_selected.insertAdjacentElement("beforebegin", new_plot); - plot_selected.insertAdjacentElement("beforebegin", new_label); - // Add event listener to new radio button - new_plot.addEventListener('change', updatePlot); - - // set plot_selected to new plot - plot_selected = new_plot - - for (let i = 0; i < label_selection.length-1; i++) { + } + var new_plot = document.createElement("input"); + new_plot.type = "radio"; + new_plot.id = "Plot " + (plot_selection.length); + new_plot.name = "plot"; + new_plot.value = "Plot " + (plot_selection.length); + new_plot.checked = true; + var new_label = document.createElement("label"); + new_label.htmlFor = "Plot " + (plot_selection.length); + new_label.innerHTML = "Plot " + (plot_selection.length); + + // Parse plot_color to get r, g, b values + var plot_color = plot_colors[plot_selection.length] + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + // set background color of new radio button to plot_color + new_label.style.backgroundColor = rgbaColor; + new_label.style.border = "2px solid " + plot_color; + new_label.style.color = plot_color; + + // insert new radio button and label before "+" radio button and after last radio button + plot_selected.insertAdjacentElement("beforebegin", new_plot); + plot_selected.insertAdjacentElement("beforebegin", new_label); + // Add event listener to new radio button + new_plot.addEventListener('change', updatePlot); + + // set plot_selected to new plot + plot_selected = new_plot + + for (let i = 0; i < label_selection.length-1; i++) { plot_selection[i].checked = false; label_selection[i].style.backgroundColor = "#ffffff"; label_selection[i].style.border = "2px solid #DADCE0"; label_selection[i].style.color = "#000000"; } } else { - for (let i = 0; i < plot_selection.length-1; i++) { + for (let i = 0; i < plot_selection.length-1; i++) { if (plot_selection[i].value !== plot_selected.value) { - plot_selection[i].checked = false; - label_selection[i].style.backgroundColor = "#ffffff"; - label_selection[i].style.border = "2px solid #DADCE0"; - label_selection[i].style.color = "#000000"; + plot_selection[i].checked = false; + label_selection[i].style.backgroundColor = "#ffffff"; + label_selection[i].style.border = "2px solid #DADCE0"; + label_selection[i].style.color = "#000000"; } else { - var plot_color = plot_colors[i+1] - const [r, g, b] = plot_color.match(/\d+/g); - const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; - plot_selected.checked = true; - label_selection[i].style.backgroundColor = rgbaColor; - label_selection[i].style.border = "2px solid " + plot_color; - label_selection[i].style.color = plot_color; + var plot_color = plot_colors[i+1] + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + plot_selected.checked = true; + label_selection[i].style.backgroundColor = rgbaColor; + label_selection[i].style.border = "2px solid " + plot_color; + label_selection[i].style.color = plot_color; + } } - } } var slices_all = JSON.parse({{ get_slices(model_card)|safe|tojson }}); var histories_all = JSON.parse({{ get_histories(model_card)|safe|tojson }}); var thresholds_all = JSON.parse({{ get_thresholds(model_card)|safe|tojson }}); - var trends_all = JSON.parse({{ get_trends(model_card)|safe|tojson }}); var passed_all = JSON.parse({{ get_passed(model_card)|safe|tojson }}); var names_all = JSON.parse({{ get_names(model_card)|safe|tojson }}); var timestamps_all = JSON.parse({{ get_timestamps(model_card)|safe|tojson }}); for (let i = 0; i < selection.length; i++) { - // use selection to set label_slice_selection background color - for (let j = 0; j < inputs_all.length; j++) { + // use selection to set label_slice_selection background color + for (let j = 0; j < inputs_all.length; j++) { if (inputs_all[j].name === selection[i].split(":")[0]) { - if (inputs_all[j].value == selection[i].split(":")[1]) { + if (inputs_all[j].value == selection[i].split(":")[1]) { inputs_all[j].checked = true; const [r, g, b] = plot_color.match(/\d+/g); const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; label_slice_selection[j].style.backgroundColor = rgbaColor; label_slice_selection[j].style.border = "2px solid " + plot_color; label_slice_selection[j].style.color = plot_color; - } - else { + } + else { inputs_all[j].checked = false; label_slice_selection[j].style.backgroundColor = "#ffffff"; label_slice_selection[j].style.border = "2px solid #DADCE0"; label_slice_selection[j].style.color = "#000000"; - } + } + } } - } } var radioGroups = {}; var labelGroups = {}; for (let i = 0; i < inputs_all.length; i++) { - var input = inputs_all[i]; - var label = label_slice_selection[i]; - var groupName = input.name; - if (!radioGroups[groupName]) { + var input = inputs_all[i]; + var label = label_slice_selection[i]; + var groupName = input.name; + if (!radioGroups[groupName]) { radioGroups[groupName] = []; labelGroups[groupName] = []; - } - radioGroups[groupName].push(input); - labelGroups[groupName].push(label); + } + radioGroups[groupName].push(input); + labelGroups[groupName].push(label); } // use radioGroups to loop through selection changing only one element at a time for (let i = 0; i < selection.length; i++) { - for (let j = 0; j < inputs_all.length; j++) { + for (let j = 0; j < inputs_all.length; j++) { if (inputs_all[j].name === selection[i].split(":")[0]) { - radio_group = radioGroups[selection[i].split(":")[0]]; - label_group = labelGroups[selection[i].split(":")[0]]; - for (let k = 0; k < radio_group.length; k++) { + radio_group = radioGroups[selection[i].split(":")[0]]; + label_group = labelGroups[selection[i].split(":")[0]]; + for (let k = 0; k < radio_group.length; k++) { selection_copy = selection.slice(); selection_copy[i] = selection[i].split(":")[0] + ":" + radio_group[k].value; // get idx of slices where all elements match var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection_copy.sort())); if (idx === undefined) { - // set radio button to disabled and cursor to not allowed and color to gray if idx is undefined - radio_group[k].disabled = true; - label_group[k].style.cursor = "not-allowed"; - label_group[k].style.color = "gray"; - label_group[k].style.backgroundColor = "rgba(125, 125, 125, 0.2)"; + // set radio button to disabled and cursor to not allowed and color to gray if idx is undefined + radio_group[k].disabled = true; + label_group[k].style.cursor = "not-allowed"; + label_group[k].style.color = "gray"; + label_group[k].style.backgroundColor = "rgba(125, 125, 125, 0.2)"; } else { - radio_group[k].disabled = false; - label_group[k].style.cursor = "pointer"; + radio_group[k].disabled = false; + label_group[k].style.cursor = "pointer"; + } } - } } - } + } } traces = []; for (let i = 0; i < selections.length; i++) { - if (selections[i] === null) { + if (selections[i] === null) { continue; - } - selection = selections[i] - // get idx of slices where all elements match - var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection)); - var history_data = []; - for (let i = 0; i < histories_all[idx].length; i++) { + } + selection = selections[i] + // get idx of slices where all elements match + var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection)); + var history_data = []; + for (let i = 0; i < histories_all[idx].length; i++) { history_data.push(parseFloat(histories_all[idx][i])); - } - var timestamp_data = []; - for (let i = 0; i < timestamps_all[idx].length; i++) { + } + var timestamp_data = []; + for (let i = 0; i < timestamps_all[idx].length; i++) { timestamp_data.push(timestamps_all[idx][i]); - } - threshold = parseFloat(thresholds_all[idx]); - trend = trends_all[idx]; - passed = passed_all[idx]; - name = names_all[idx]; - - // if trend is "positive" set keyword to upwards, if trend is "negative" set keyword to downwards, else set keyword to flat - if (trend === "positive") { + } + var last_n_evals = document.getElementById("n_evals_slider_pot").value; + history_data = history_data.slice(-last_n_evals); + timestamp_data = timestamp_data.slice(-last_n_evals); + // get slope of line of best fit, if >0.01 then trending up, if <0.01 then trending down, else flat + var slope = lineOfBestFit(history_data)[0]; + if (slope > 0.01) { var trend_keyword = "upwards"; - } else if (trend === "negative") { + } + else if (slope < -0.01) { var trend_keyword = "downwards"; - } else { + } + else { var trend_keyword = "flat"; - } + } - // if passed is true set keyword to Above, if passed is false set keyword to Below - if (passed) { + threshold = parseFloat(thresholds_all[idx]); + passed = passed_all[idx]; + name = names_all[idx]; + + // if passed is true set keyword to Above, if passed is false set keyword to Below + if (passed) { var passed_keyword = "above"; - } - else { + } + else { var passed_keyword = "below"; - } + } - // create title for plot: Current {metric name} is trending {trend_keyword} and is {passed_keyword} the threshold. - // get number of nulls in selections, if 9 then plot title, else don't plot title - var nulls = 0; - for (let i = 0; i < selections.length; i++) { + // create title for plot: Current {metric name} is trending {trend_keyword} and is {passed_keyword} the threshold. + // get number of nulls in selections, if 9 then plot title, else don't plot title + var nulls = 0; + for (let i = 0; i < selections.length; i++) { if (selections[i] === null) { - nulls += 1; + nulls += 1; } - } - if (nulls === 10) { + } + if (nulls === 10) { var plot_title = "Current " + name + " is trending " + trend_keyword + " and is " + passed_keyword + " the threshold."; var plot_title = multipleStringLines(plot_title); var showlegend = false; - } - else { + } + else { var plot_title = ""; var showlegend = true; - } - name = "" - suffix = " ( " - for (let i = 0; i < selection.length; i++) { + } + name = "" + suffix = " ( " + for (let i = 0; i < selection.length; i++) { if (selection[i].split(":")[0] === "metric") { - name += selection[i].split(":")[1]; + name += selection[i].split(":")[1]; } else { - if (selection[i].split(":")[1].includes("overall")) { + if (selection[i].split(":")[1].includes("overall")) { continue; - } else { + } else { suffix += selection[i]; suffix += ", "; - } + } } - } - if (suffix === " ( ") { + } + if (suffix === " ( ") { name += ""; - } - else { + } + else { suffix = suffix.slice(0, -2); name += suffix + " )"; - } - var trace = { + } + var trace = { // range of x is the length of the list of floats x: timestamp_data, y: history_data, @@ -246,12 +250,12 @@ function updatePlot() { line: {color: plot_colors[i+1]}, name: name, //name: selection.toString(), - }; - traces.push(trace); + }; + traces.push(trace); } if (nulls === 10) { - var threshold_trace = { + var threshold_trace = { x: timestamp_data, y: Array.from({length: history_data.length}, (_, i) => threshold), mode: 'lines', @@ -259,64 +263,55 @@ function updatePlot() { marker: {color: 'rgb(0,0,0)'}, line: {color: 'rgb(0,0,0)', dash: 'dot'}, name: '', - }; - traces.push(threshold_trace); + }; + traces.push(threshold_trace); } var width = Math.max(parent.innerWidth - 900, 500); var layout = { - title: { + title: { text: plot_title, font: { - family: 'Arial, Helvetica, sans-serif', - size: 18, + family: 'Arial, Helvetica, sans-serif', + size: 18, } - }, - paper_bgcolor: 'rgba(0,0,0,0)', - plot_bgcolor: 'rgba(0,0,0,0)', - xaxis: { + }, + paper_bgcolor: 'rgba(0,0,0,0)', + plot_bgcolor: 'rgba(0,0,0,0)', + xaxis: { zeroline: false, showticklabels: true, showgrid: false, tickformat: '%b\n %Y' - }, - yaxis: { + }, + yaxis: { gridcolor: '#ffffff', zeroline: false, showticklabels: true, showgrid: true, range: [-0.10, 1.10], - }, - showlegend: showlegend, - // show legend at top - legend: { + }, + showlegend: showlegend, + // show legend at top + legend: { orientation: "h", yanchor: "top", y: 1.1, xanchor: "left", x: 0.1 - }, - margin: { + }, + margin: { l: 50, r: 50, b: 50, t: 50, pad: 4 - }, - // set height and width of plot to width of card minus 500px - height: 500, - width: width, + }, + // set height and width of plot to width of card minus 500px + height: 500, + width: width, } Plotly.newPlot(plot, traces, layout, {displayModeBar: false}); -} -// Add event listeners to radio buttons -for (let input of inputs_all) { - input.addEventListener('change', updatePlot); -} -// Add event listener to update plot when window is resized -window.addEventListener('resize', updatePlot); -for (let selection of plot_selection) { - selection.addEventListener('change', updatePlotSelection); -} + } function generate_model_card_plot() { @@ -327,21 +322,24 @@ function generate_model_card_plot() { var timestamps = JSON.parse({{ get_timestamps(model_card)|safe|tojson }}); for (let i = 0; i < overall_indices.length; i++) { - var idx = overall_indices[i]; - var model_card_plot = "model-card-plot-" + idx; - var threshold = thresholds[idx]; - var history_data = []; - for (let i = 0; i < histories[idx].length; i++) { + var idx = overall_indices[i]; + var model_card_plot = "model-card-plot-" + idx; + var threshold = thresholds[idx]; + var history_data = []; + for (let i = 0; i < histories[idx].length; i++) { history_data.push(parseFloat(histories[idx][i])); - } - var timestamp_data = []; - for (let i = 0; i < timestamps[idx].length; i++) { + } + var timestamp_data = []; + for (let i = 0; i < timestamps[idx].length; i++) { timestamp_data.push(timestamps[idx][i]); - } + } + var last_n_evals = document.getElementById("n_evals_slider_p").value; + history_data = history_data.slice(-last_n_evals); + timestamp_data = timestamp_data.slice(-last_n_evals); - var model_card_fig = { + var model_card_fig = { data: [ - { + { x: timestamp_data, y: history_data, mode: "lines+markers", @@ -350,8 +348,8 @@ function generate_model_card_plot() { showlegend: false, type: "scatter", name: "" - }, - { + }, + { x: timestamp_data, y: Array(history_data.length).fill(threshold), mode: "lines", @@ -359,35 +357,35 @@ function generate_model_card_plot() { showlegend: false, type: "scatter", name: "" - } + } ], layout: { - paper_bgcolor: "rgba(0,0,0,0)", - plot_bgcolor: "rgba(0,0,0,0)", - xaxis: { + paper_bgcolor: "rgba(0,0,0,0)", + plot_bgcolor: "rgba(0,0,0,0)", + xaxis: { zeroline: false, showticklabels: true, showgrid: false, tickformat: '%b\n %Y' - }, - yaxis: { + }, + yaxis: { gridcolor: "#ffffff", zeroline: false, showticklabels: true, showgrid: true, range: [-0.10, 1.10], - }, - margin: { l: 30, r: 0, t: 0, b: 30 }, - padding: { l: 0, r: 0, t: 0, b: 0 }, - height: 150, - width: 300 - } - }; - if (history.length > 0) { - Plotly.newPlot(model_card_plot, model_card_fig.data, model_card_fig.layout, {displayModeBar: false}); + }, + margin: { l: 30, r: 0, t: 0, b: 30 }, + padding: { l: 0, r: 0, t: 0, b: 0 }, + height: 150, + width: 300 + } + }; + if (history.length > 0) { + Plotly.newPlot(model_card_plot, model_card_fig.data, model_card_fig.layout, {displayModeBar: false}); + } } } -} function updatePlotSelection() { @@ -402,143 +400,142 @@ function updatePlotSelection() { // if plot_selected is "+" then add new radio button to plot_selection called "Plot N" where last plot is N-1 but keep "+" at end and set new radio button to checked for second last element if (plot_selected.value === "+") { - // if 10 plots already exist, don't add new plot and gray out "+" - if (plot_selection.length === 11) { + // if 10 plots already exist, don't add new plot and gray out "+" + if (plot_selection.length === 11) { plot_selected.checked = false; label_selection[label_selection.length-1].style.color = "gray"; return; - } - // plot_name should be name of last plot + 1 - if (plot_selection.length === 2) { + } + // plot_name should be name of last plot + 1 + if (plot_selection.length === 2) { var plot_name = "Plot 2" - } else { + } else { var plot_name = "Plot " + (parseInt(plot_selection[plot_selection.length - 2].value.split(" ")[1]) + 1); - } - var new_plot = document.createElement("input"); - new_plot.type = "radio"; - new_plot.id = plot_name; - new_plot.name = "plot"; - new_plot.value = plot_name; - new_plot.checked = true; - var new_label = document.createElement("label"); - new_label.htmlFor = plot_name; - new_label.innerHTML = plot_name; - - // Parse plot_color to get r, g, b values - var plot_color = plot_colors[plot_selection.length] - const [r, g, b] = plot_color.match(/\d+/g); - const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; - // set background color of new radio button to plot_color - new_label.style.backgroundColor = rgbaColor; - new_label.style.border = "2px solid " + plot_color; - new_label.style.color = plot_color; - - // add button to delete plot - var delete_button = document.createElement("button"); - delete_button.id = "button"; - delete_button.innerHTML = "×"; - delete_button.style.backgroundColor = "transparent"; - delete_button.style.border = "none"; - new_label.style.padding = "1.5px 0px"; - new_label.style.paddingLeft = "10px"; - - new_label.appendChild(delete_button) - - // make delete button from last plot invisible if not Plot 1 - if (plot_selection.length > 2) { + } + var new_plot = document.createElement("input"); + new_plot.type = "radio"; + new_plot.id = plot_name; + new_plot.name = "plot"; + new_plot.value = plot_name; + new_plot.checked = true; + var new_label = document.createElement("label"); + new_label.htmlFor = plot_name; + new_label.innerHTML = plot_name; + + // Parse plot_color to get r, g, b values + var plot_color = plot_colors[plot_selection.length] + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + // set background color of new radio button to plot_color + new_label.style.backgroundColor = rgbaColor; + new_label.style.border = "2px solid " + plot_color; + new_label.style.color = plot_color; + + // add button to delete plot + var delete_button = document.createElement("button"); + delete_button.id = "button"; + delete_button.innerHTML = "×"; + delete_button.style.backgroundColor = "transparent"; + delete_button.style.border = "none"; + new_label.style.padding = "1.5px 0px"; + new_label.style.paddingLeft = "10px"; + + new_label.appendChild(delete_button) + + // make delete button from last plot invisible if not Plot 1 + if (plot_selection.length > 2) { button_plot_selection[button_plot_selection.length-1].style.visibility = "hidden"; - } - // add on_click event to delete button and send plot number to deletePlotSelection - delete_button.onclick = function() {deletePlotSelection(plot_number)}; + } + // add on_click event to delete button and send plot number to deletePlotSelection + delete_button.onclick = function() {deletePlotSelection(plot_number)}; - // insert new radio button and label before "+" radio button and after last radio button - plot_selected.insertAdjacentElement("beforebegin", new_plot); - plot_selected.insertAdjacentElement("beforebegin", new_label); + // insert new radio button and label before "+" radio button and after last radio button + plot_selected.insertAdjacentElement("beforebegin", new_plot); + plot_selected.insertAdjacentElement("beforebegin", new_label); - // Add event listener to new radio button - new_plot.addEventListener('change', updatePlotSelection); + // Add event listener to new radio button + new_plot.addEventListener('change', updatePlotSelection); - // set plot_selected to new plot - var plot_selected = new_plot + // set plot_selected to new plot + var plot_selected = new_plot - for (let i = 0; i < label_selection.length-1; i++) { + for (let i = 0; i < label_selection.length-1; i++) { plot_selection[i].checked = false; label_selection[i].style.backgroundColor = "#ffffff"; label_selection[i].style.border = "2px solid #DADCE0"; label_selection[i].style.color = "#000000"; } - selections[parseInt(plot_selected.value.split(" ")[1]-1)] = selections[parseInt(plot_selected.value.split(" ")[1]-2)] - selection = selections[parseInt(plot_selected.value.split(" ")[1]-1)]; - plot_color = plot_colors[parseInt(plot_selected.value.split(" ")[1])]; + selections[parseInt(plot_selected.value.split(" ")[1]-1)] = selections[parseInt(plot_selected.value.split(" ")[1]-2)] + selection = selections[parseInt(plot_selected.value.split(" ")[1]-1)]; + plot_color = plot_colors[parseInt(plot_selected.value.split(" ")[1])]; - for (let i = 0; i < selection.length; i++) { + for (let i = 0; i < selection.length; i++) { // use selection to set label_slice_selection background color for (let j = 0; j < inputs_all.length; j++) { - if (inputs_all[j].name === selection[i].split(":")[0]) { + if (inputs_all[j].name === selection[i].split(":")[0]) { if (inputs_all[j].value == selection[i].split(":")[1]) { - const [r, g, b] = plot_color.match(/\d+/g); - const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; - inputs_all[j].checked = true; - label_slice_selection[j].style.backgroundColor = rgbaColor; - label_slice_selection[j].style.border = "2px solid " + plot_color; - label_slice_selection[j].style.color = plot_color; + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + inputs_all[j].checked = true; + label_slice_selection[j].style.backgroundColor = rgbaColor; + label_slice_selection[j].style.border = "2px solid " + plot_color; + label_slice_selection[j].style.color = plot_color; } else { - inputs_all[j].checked = false; - label_slice_selection[j].style.backgroundColor = "#ffffff"; - label_slice_selection[j].style.border = "2px solid #DADCE0"; - label_slice_selection[j].style.color = "#000000"; + inputs_all[j].checked = false; + label_slice_selection[j].style.backgroundColor = "#ffffff"; + label_slice_selection[j].style.border = "2px solid #DADCE0"; + label_slice_selection[j].style.color = "#000000"; + } } - } } - } + } } else { - for (let i = 0; i < plot_selection.length-1; i++) { + for (let i = 0; i < plot_selection.length-1; i++) { if (plot_selection[i].value !== plot_selected.value) { - plot_selection[i].checked = false; - label_selection[i].style.backgroundColor = "#ffffff"; - label_selection[i].style.border = "2px solid #DADCE0"; - label_selection[i].style.color = "#000000"; + plot_selection[i].checked = false; + label_selection[i].style.backgroundColor = "#ffffff"; + label_selection[i].style.border = "2px solid #DADCE0"; + label_selection[i].style.color = "#000000"; } else { - var plot_color = plot_colors[i+1] - const [r, g, b] = plot_color.match(/\d+/g); - const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; - plot_selected.checked = true; - label_selection[i].style.backgroundColor = rgbaColor; - label_selection[i].style.border = "2px solid " + plot_color; - label_selection[i].style.color = plot_color; + var plot_color = plot_colors[i+1] + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + plot_selected.checked = true; + label_selection[i].style.backgroundColor = rgbaColor; + label_selection[i].style.border = "2px solid " + plot_color; + label_selection[i].style.color = plot_color; } - } - selection = selections[parseInt(plot_selected.value.split(" ")[1]-1)]; - plot_color = plot_colors[parseInt(plot_selected.value.split(" ")[1])]; - for (let i = 0; i < selection.length; i++) { + } + selection = selections[parseInt(plot_selected.value.split(" ")[1]-1)]; + plot_color = plot_colors[parseInt(plot_selected.value.split(" ")[1])]; + for (let i = 0; i < selection.length; i++) { // use selection to set label_slice_selection background color for (let j = 0; j < inputs_all.length; j++) { - if (inputs_all[j].name === selection[i].split(":")[0]) { + if (inputs_all[j].name === selection[i].split(":")[0]) { if (inputs_all[j].value == selection[i].split(":")[1]) { - inputs_all[j].checked = true; - const [r, g, b] = plot_color.match(/\d+/g); - const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; - label_slice_selection[j].style.backgroundColor = rgbaColor; - label_slice_selection[j].style.border = "2px solid " + plot_color; - label_slice_selection[j].style.color = plot_color; + inputs_all[j].checked = true; + const [r, g, b] = plot_color.match(/\d+/g); + const rgbaColor = `rgba(${r}, ${g}, ${b}, 0.2)`; + label_slice_selection[j].style.backgroundColor = rgbaColor; + label_slice_selection[j].style.border = "2px solid " + plot_color; + label_slice_selection[j].style.color = plot_color; } else { - inputs_all[j].checked = false; - label_slice_selection[j].style.backgroundColor = "#ffffff"; - label_slice_selection[j].style.border = "2px solid #DADCE0"; - label_slice_selection[j].style.color = "#000000"; + inputs_all[j].checked = false; + label_slice_selection[j].style.backgroundColor = "#ffffff"; + label_slice_selection[j].style.border = "2px solid #DADCE0"; + label_slice_selection[j].style.color = "#000000"; + } } - } } - } + } } var slices_all = JSON.parse({{ get_slices(model_card)|safe|tojson }}); var histories_all = JSON.parse({{ get_histories(model_card)|safe|tojson }}); var thresholds_all = JSON.parse({{ get_thresholds(model_card)|safe|tojson }}); - var trends_all = JSON.parse({{ get_trends(model_card)|safe|tojson }}); var passed_all = JSON.parse({{ get_passed(model_card)|safe|tojson }}); var names_all = JSON.parse({{ get_names(model_card)|safe|tojson }}); var timestamps_all = JSON.parse({{ get_timestamps(model_card)|safe|tojson }}); @@ -546,125 +543,130 @@ function updatePlotSelection() { var radioGroups = {}; var labelGroups = {}; for (let i = 0; i < inputs_all.length; i++) { - var input = inputs_all[i]; - var label = label_slice_selection[i]; - var groupName = input.name; - if (!radioGroups[groupName]) { + var input = inputs_all[i]; + var label = label_slice_selection[i]; + var groupName = input.name; + if (!radioGroups[groupName]) { radioGroups[groupName] = []; labelGroups[groupName] = []; - } - radioGroups[groupName].push(input); - labelGroups[groupName].push(label); + } + radioGroups[groupName].push(input); + labelGroups[groupName].push(label); } // use radioGroups to loop through selection changing only one element at a time for (let i = 0; i < selection.length; i++) { - for (let j = 0; j < inputs_all.length; j++) { + for (let j = 0; j < inputs_all.length; j++) { if (inputs_all[j].name === selection[i].split(":")[0]) { - radio_group = radioGroups[selection[i].split(":")[0]]; - label_group = labelGroups[selection[i].split(":")[0]]; - for (let k = 0; k < radio_group.length; k++) { + radio_group = radioGroups[selection[i].split(":")[0]]; + label_group = labelGroups[selection[i].split(":")[0]]; + for (let k = 0; k < radio_group.length; k++) { selection_copy = selection.slice(); selection_copy[i] = selection[i].split(":")[0] + ":" + radio_group[k].value; // get idx of slices where all elements match var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection_copy.sort())); if (idx === undefined) { - // set radio button to disabled and cursor to not allowed and color to gray if idx is undefined - radio_group[k].disabled = true; - label_group[k].style.cursor = "not-allowed"; - label_group[k].style.color = "gray"; - label_group[k].style.backgroundColor = "rgba(125, 125, 125, 0.2)"; + // set radio button to disabled and cursor to not allowed and color to gray if idx is undefined + radio_group[k].disabled = true; + label_group[k].style.cursor = "not-allowed"; + label_group[k].style.color = "gray"; + label_group[k].style.backgroundColor = "rgba(125, 125, 125, 0.2)"; } else { - radio_group[k].disabled = false; - label_group[k].style.cursor = "pointer"; + radio_group[k].disabled = false; + label_group[k].style.cursor = "pointer"; + } } - } } - } + } } traces = []; var plot_number = parseInt(plot_selected.value.split(" ")[1]-1); for (let i = 0; i < selections.length; i++) { - if (selections[i] === null) { + if (selections[i] === null) { continue; - } - selection = selections[i] + } + selection = selections[i] - // get idx of slices where all elements match - var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection)); - var history_data = []; - for (let i = 0; i < histories_all[idx].length; i++) { + // get idx of slices where all elements match + var idx = Object.keys(slices_all).find(key => JSON.stringify(slices_all[key].sort()) === JSON.stringify(selection)); + var history_data = []; + for (let i = 0; i < histories_all[idx].length; i++) { history_data.push(parseFloat(histories_all[idx][i])); - } - var timestamp_data = []; - for (let i = 0; i < timestamps_all[idx].length; i++) { + } + var timestamp_data = []; + for (let i = 0; i < timestamps_all[idx].length; i++) { timestamp_data.push(timestamps_all[idx][i]); - } - threshold = parseFloat(thresholds_all[idx]); - trend = trends_all[idx]; - passed = passed_all[idx]; - name = names_all[idx]; - - // if trend is "positive" set keyword to upwards, if trend is "negative" set keyword to downwards, else set keyword to flat - if (trend === "positive") { + } + var last_n_evals = document.getElementById("n_evals_slider_pot").value; + history_data = history_data.slice(-last_n_evals); + timestamp_data = timestamp_data.slice(-last_n_evals); + // get slope of line of best fit, if >0.01 then trending up, if <0.01 then trending down, else flat + var slope = lineOfBestFit(history_data)[0]; + if (slope > 0.01) { var trend_keyword = "upwards"; - } else if (trend === "negative") { + } + else if (slope < -0.01) { var trend_keyword = "downwards"; - } else { + } + else { var trend_keyword = "flat"; - } + } - // if passed is true set keyword to Above, if passed is false set keyword to Below - if (passed) { + threshold = parseFloat(thresholds_all[idx]); + passed = passed_all[idx]; + name = names_all[idx]; + + // if passed is true set keyword to Above, if passed is false set keyword to Below + if (passed) { var passed_keyword = "above"; - } - else { + } + else { var passed_keyword = "below"; - } + } - // create title for plot: Current {metric name} is trending {trend_keyword} and is {passed_keyword} the threshold. - // get number of nulls in selections, if 9 then plot title, else don't plot title - var nulls = 0; - for (let i = 0; i < selections.length; i++) { + // create title for plot: Current {metric name} is trending {trend_keyword} and is {passed_keyword} the threshold. + // get number of nulls in selections, if 9 then plot title, else don't plot title + var nulls = 0; + for (let i = 0; i < selections.length; i++) { if (selections[i] === null) { - nulls += 1; + nulls += 1; } - } - if (nulls === 10) { - var plot_title = "Current " + name + " is trending " + trend_keyword + " and is " + passed_keyword + " the threshold."; + } + if (nulls === 10) { + var plot_title = "Current " + name + " is trending " + "flat" + " and is " + passed_keyword + " the threshold."; var plot_title = multipleStringLines(plot_title); var showlegend = false; - } - else { + } + else { var plot_title = ""; var showlegend = true; - } - name = "" - suffix = " ( " - for (let i = 0; i < selection.length; i++) { + } + name = "" + suffix = " ( " + for (let i = 0; i < selection.length; i++) { if (selection[i].split(":")[0] === "metric") { - name += selection[i].split(":")[1]; + name += selection[i].split(":")[1]; } else { - if (selection[i].split(":")[1].includes("overall")) { + if (selection[i].split(":")[1].includes("overall")) { continue; - } else { + } else { suffix += selection[i]; suffix += ", "; - } + } } - } - if (suffix === " ( ") { + } + if (suffix === " ( ") { name += ""; - } - else { + } + else { suffix = suffix.slice(0, -2); name += suffix + " )"; - } + } - var trace = { + var trace = { // range of x is the length of the list of floats x: timestamp_data, y: history_data, @@ -673,12 +675,12 @@ function updatePlotSelection() { marker: {color: plot_colors[i+1]}, line: {color: plot_colors[i+1]}, name: name, - }; - traces.push(trace); + }; + traces.push(trace); } if (nulls === 10) { - var threshold_trace = { + var threshold_trace = { x: timestamp_data, y: Array.from({length: history_data.length}, (_, i) => threshold), mode: 'lines', @@ -686,53 +688,53 @@ function updatePlotSelection() { marker: {color: 'rgb(0,0,0)'}, line: {color: 'rgb(0,0,0)', dash: 'dot'}, name: '', - }; - traces.push(threshold_trace); + }; + traces.push(threshold_trace); } var width = Math.max(parent.innerWidth - 900, 500); var layout = { - title: { + title: { text: plot_title, font: { - family: 'Arial, Helvetica, sans-serif', - size: 18, + family: 'Arial, Helvetica, sans-serif', + size: 18, } - }, - paper_bgcolor: 'rgba(0,0,0,0)', - plot_bgcolor: 'rgba(0,0,0,0)', - xaxis: { + }, + paper_bgcolor: 'rgba(0,0,0,0)', + plot_bgcolor: 'rgba(0,0,0,0)', + xaxis: { zeroline: false, showticklabels: true, showgrid: false, tickformat: '%b\n %Y' - }, - yaxis: { + }, + yaxis: { gridcolor: '#ffffff', zeroline: false, showticklabels: true, showgrid: true, range: [-0.10, 1.10], - }, - showlegend: showlegend, - // show legend at top - legend: { + }, + showlegend: showlegend, + // show legend at top + legend: { orientation: "h", yanchor: "top", y: 1.1, xanchor: "left", x: 0.1 - }, - margin: { + }, + margin: { l: 50, r: 50, b: 50, t: 50, pad: 4 - }, - // set height and width of plot to extra-wide to fit the plot - height: 500, - // get size of window and set width of plot to size of window - width: width, + }, + // set height and width of plot to extra-wide to fit the plot + height: 500, + // get size of window and set width of plot to size of window + width: width, } Plotly.newPlot(plot, traces, layout, {displayModeBar: false}); } @@ -807,3 +809,21 @@ function refreshPlotlyPlots() { } } } + +function lineOfBestFit(y) { + var x = Array.from({length: y.length}, (_, i) => i); + var n = x.length; + var x_sum = 0; + var y_sum = 0; + var xy_sum = 0; + var xx_sum = 0; + for (var i = 0; i < n; i++) { + x_sum += x[i]; + y_sum += y[i]; + xy_sum += x[i] * y[i]; + xx_sum += x[i] * x[i]; + } + var m = (n * xy_sum - x_sum * y_sum) / (n * xx_sum - x_sum * x_sum); + var b = (y_sum - m * x_sum) / n; + return [m, b]; + } diff --git a/cyclops/report/utils.py b/cyclops/report/utils.py index 5c8c8bbac..41254f8f5 100644 --- a/cyclops/report/utils.py +++ b/cyclops/report/utils.py @@ -516,21 +516,6 @@ def get_thresholds(model_card: ModelCard) -> str: return json.dumps(thresholds) -def get_trends(model_card: ModelCard) -> str: - """Get all trends from a model card.""" - trends: Dict[int, Optional[str]] = {} - if ( - (model_card.overview is None) - or (model_card.overview.metric_cards is None) - or (model_card.overview.metric_cards.collection is None) - ): - pass - else: - for itr, metric_card in enumerate(model_card.overview.metric_cards.collection): - trends[itr] = metric_card.trend - return json.dumps(trends) - - def get_passed(model_card: ModelCard) -> str: """Get all passed from a model card.""" passed: Dict[int, Optional[bool]] = {} @@ -743,13 +728,6 @@ def create_metric_cards( # noqa: PLR0912 PLR0915 timestamps = metric["last_metric_card"].timestamps if timestamps is not None: timestamps.append(timestamp) - (m, _) = np.polyfit(range(len(history)), history, deg=1) - if m >= 0.01: - trend = "positive" - elif m <= -0.01: - trend = "negative" - else: - trend = "neutral" metric_cards.append( MetricCard( @@ -784,9 +762,6 @@ def create_metric_cards( # noqa: PLR0912 PLR0915 ) else None, history=history, - trend=trend - if isinstance(metric["current_metric"], PerformanceMetric) - else None, timestamps=timestamps, ), ) @@ -846,7 +821,6 @@ def create_metric_cards( # noqa: PLR0912 PLR0915 and isinstance(metric["current_metric"].value, float) else 0, ], - trend="neutral", timestamps=[timestamp], ), ) diff --git a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb index bd62e7929..8d461cedd 100644 --- a/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb +++ b/docs/source/tutorials/kaggle/heart_failure_prediction.ipynb @@ -32,6 +32,7 @@ "from datetime import date\n", "\n", "import numpy as np\n", + "import pandas as pd\n", "import plotly.express as px\n", "from datasets import Dataset\n", "from datasets.features import ClassLabel\n", @@ -40,6 +41,7 @@ "from sklearn.impute import SimpleImputer\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import MinMaxScaler, OneHotEncoder\n", + "from tqdm import tqdm\n", "\n", "from cyclops.data.df.feature import TabularFeatures\n", "from cyclops.data.slicer import SliceSpec\n", @@ -1283,23 +1285,23 @@ }, "outputs": [], "source": [ - "synthetic_timestamps = [\n", - " \"2021-09-01\",\n", - " \"2021-10-01\",\n", - " \"2021-11-01\",\n", - " \"2021-12-01\",\n", - " \"2022-01-01\",\n", - " \"2022-02-01\",\n", - "]\n", + "np.random.seed(42)\n", + "\n", + "synthetic_timestamps = pd.date_range(\n", + " start=\"1/1/2020\", periods=10, freq=\"D\"\n", + ").values.astype(str)\n", + "\n", + "\n", "report._model_card.overview = None\n", "report_path = report.export(\n", " output_filename=\"heart_failure_report_periodic.html\",\n", " synthetic_timestamp=synthetic_timestamps[0],\n", + " last_n_evals=3,\n", ")\n", "\n", "shutil.copy(f\"{report_path}\", \".\")\n", "metric_save = None\n", - "for i in range(1, 5):\n", + "for i in tqdm(range(len(synthetic_timestamps[1:]))):\n", " if i == 3:\n", " report._model_card.quantitative_analysis.performance_metrics.append(\n", " metric_save,\n", @@ -1323,9 +1325,10 @@ " report_path = report.export(\n", " output_filename=\"heart_failure_report_periodic.html\",\n", " synthetic_timestamp=synthetic_timestamps[i + 1],\n", + " last_n_evals=3,\n", " )\n", " shutil.copy(f\"{report_path}\", \".\")\n", - "shutil.rmtree(\"./cyclops_reports\")" + "shutil.rmtree(\"./cyclops_report\")" ] }, { diff --git a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb index b480f2dfd..8a1ed345c 100644 --- a/docs/source/tutorials/mimiciv/mortality_prediction.ipynb +++ b/docs/source/tutorials/mimiciv/mortality_prediction.ipynb @@ -1286,7 +1286,7 @@ " synthetic_timestamp=synthetic_timestamps[i + 1],\n", " )\n", " shutil.copy(f\"{report_path}\", \".\")\n", - "shutil.rmtree(\"./cyclops_reports\")" + "shutil.rmtree(\"./cyclops_report\")" ] }, { diff --git a/docs/source/tutorials/synthea/los_prediction.ipynb b/docs/source/tutorials/synthea/los_prediction.ipynb index 9f32ddaf7..0ab608036 100644 --- a/docs/source/tutorials/synthea/los_prediction.ipynb +++ b/docs/source/tutorials/synthea/los_prediction.ipynb @@ -1470,7 +1470,7 @@ " synthetic_timestamp=synthetic_timestamps[i + 1],\n", " )\n", " shutil.copy(f\"{report_path}\", \".\")\n", - "shutil.rmtree(\"./cyclops_reports\")" + "shutil.rmtree(\"./cyclops_report\")" ] }, { diff --git a/tests/cyclops/report/test_utils.py b/tests/cyclops/report/test_utils.py index 6de671ac8..3e68f0888 100644 --- a/tests/cyclops/report/test_utils.py +++ b/tests/cyclops/report/test_utils.py @@ -33,7 +33,6 @@ get_slices, get_thresholds, get_timestamps, - get_trends, sweep_graphics, sweep_metric_cards, sweep_metrics, @@ -362,14 +361,6 @@ def test_get_thresholds(model_card): assert len(thresholds_dict.values()) == 2 -def test_get_trends(model_card): - """Test get_trends function.""" - trends = get_trends(model_card) - # read trends from json to dict - trends_dict = json.loads(trends) - assert len(trends_dict.values()) == 2 - - def test_get_passed(model_card): """Test get_passed function.""" passed = get_passed(model_card)