Skip to content

Commit

Permalink
Add theming for plotly figures (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hans Kallekleiv authored Jan 8, 2020
1 parent b6244d9 commit 148500f
Show file tree
Hide file tree
Showing 10 changed files with 171 additions and 148 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"matplotlib~=3.0",
"pillow~=6.1",
"xtgeo~=2.1",
"webviz-config>=0.0.35",
"webviz-config>=0.0.41",
# webviz-subsurface-components is part of the webviz-subsurface project,
# just located in a separate repository for convenience,
# and is therefore pinned exactly here:
Expand Down
6 changes: 5 additions & 1 deletion tests/integration_tests/test_parameter_corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dash
import pandas as pd
from webviz_config.common_cache import CACHE
from webviz_config.themes import default_theme
from webviz_config.plugins import ParameterCorrelation

# mocked functions
Expand All @@ -15,7 +16,10 @@ def test_parameter_corr(dash_duo):
app.scripts.config.serve_locally = True
app.config.suppress_callback_exceptions = True
CACHE.init_app(app.server)
app.webviz_settings = {"shared_settings": {"scratch_ensembles": {"iter-0": ""}}}
app.webviz_settings = {
"shared_settings": {"scratch_ensembles": {"iter-0": ""}},
"theme": default_theme,
}
ensembles = ["iter-0"]

with mock.patch(get_parameters) as mock_parameters:
Expand Down
99 changes: 57 additions & 42 deletions webviz_subsurface/_private_plugins/tornado_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
)
self.allow_click = allow_click
self.uid = uuid4()
self.plotly_theme = app.webviz_settings["theme"].plotly_theme
self.set_callbacks(app)

def ids(self, element):
Expand Down Expand Up @@ -212,7 +213,14 @@ def _calc_tornado(reference, scale, cutbyref, data):
self.realizations["ENSEMBLE"] == data["ENSEMBLE"]
]
try:
return tornado_plot(realizations, values, reference, scale, cutbyref)
return tornado_plot(
realizations,
values,
plotly_theme=self.plotly_theme,
reference=reference,
scale=scale,
cutbyref=cutbyref,
)
except KeyError:
return {}

Expand Down Expand Up @@ -288,7 +296,12 @@ def cut_by_ref(tornadotable, refname):

@CACHE.memoize(timeout=CACHE.TIMEOUT)
def tornado_plot(
realizations, data, reference="rms_seed", scale="Percentage", cutbyref=True
realizations,
data,
plotly_theme,
reference="rms_seed",
scale="Percentage",
cutbyref=True,
): # pylint: disable=too-many-locals

# Raise key error if no senscases, i.e. the ensemble has no design matrix
Expand Down Expand Up @@ -394,44 +407,44 @@ def tornado_plot(
df = sort_by_max(df)
# Return tornado data as Plotly figure

return {
"data": [
dict(
type="bar",
y=df["sensname"],
x=df["low"],
name="low",
customdata=df["low_reals"],
hovertext=[
f"Case: {label}<br>True Value: {val:.2f}<br>Realizations:"
f"{min(reals) if reals else None}-{max(reals) if reals else None}"
for label, val, reals in zip(
df["low_label"], df["true_low"], df["low_reals"]
)
],
hoverinfo="x+text",
orientation="h",
marker=dict(color="rgb(235, 0, 54)"),
),
dict(
type="bar",
y=df["sensname"],
x=df["high"],
name="high",
customdata=df["high_reals"],
hovertext=[
f"Case: {label}<br>True Value: {val:.2f}<br>Realizations:"
f"{min(reals) if reals else None}-{max(reals) if reals else None}"
for label, val, reals in zip(
df["high_label"], df["true_high"], df["high_reals"]
)
],
hoverinfo="x+text",
orientation="h",
marker=dict(color="rgb(36, 55, 70)"),
),
],
"layout": {
plot_data = [
dict(
type="bar",
y=df["sensname"],
x=df["low"],
name="low",
customdata=df["low_reals"],
hovertext=[
f"Case: {label}<br>True Value: {val:.2f}<br>Realizations:"
f"{min(reals) if reals else None}-{max(reals) if reals else None}"
for label, val, reals in zip(
df["low_label"], df["true_low"], df["low_reals"]
)
],
hoverinfo="x+text",
orientation="h",
),
dict(
type="bar",
y=df["sensname"],
x=df["high"],
name="high",
customdata=df["high_reals"],
hovertext=[
f"Case: {label}<br>True Value: {val:.2f}<br>Realizations:"
f"{min(reals) if reals else None}-{max(reals) if reals else None}"
for label, val, reals in zip(
df["high_label"], df["true_high"], df["high_reals"]
)
],
hoverinfo="x+text",
orientation="h",
),
]
layout = {}
layout.update(plotly_theme["layout"])
layout.update(
{
"barmode": "relative",
"margin": {"l": 50, "r": 50, "b": 20, "t": 50},
"xaxis": {
Expand All @@ -448,6 +461,7 @@ def tornado_plot(
"zeroline": False,
"showline": False,
"automargin": True,
"title": None,
},
"showlegend": False,
"annotations": [
Expand All @@ -467,5 +481,6 @@ def tornado_plot(
"ay": -25,
}
],
},
}
}
)
return {"data": plot_data, "layout": layout}
46 changes: 26 additions & 20 deletions webviz_subsurface/plugins/_inplace_volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def __init__(
super().__init__()

self.csvfile = csvfile if csvfile else None
self.colorway = app.webviz_settings.get("plotly_layout", {}).get("colorway", [])
if csvfile and ensembles:
raise ValueError(
'Incorrent arguments. Either provide a "csvfile" or "ensembles" and "volfiles"'
Expand All @@ -109,7 +108,7 @@ def __init__(
self.initial_response = response
self.uid = uuid4()
self.selectors_id = {x: str(uuid4()) for x in self.selectors}

self.plotly_theme = app.webviz_settings["theme"].plotly_theme
self.set_callbacks(app)

def ids(self, element):
Expand Down Expand Up @@ -426,7 +425,7 @@ def _render_vol_chart(*args):
return (
{
"data": plot_traces,
"layout": plot_layout(plot_type, response, colors=self.colorway),
"layout": plot_layout(plot_type, response, theme=self.plotly_theme),
},
table,
)
Expand Down Expand Up @@ -506,26 +505,33 @@ def plot_table(dframe, response, name):


@CACHE.memoize(timeout=CACHE.TIMEOUT)
def plot_layout(plot_type, response, colors):
def plot_layout(plot_type, response, theme):
layout = {}
layout.update(theme["layout"])
layout["height"] = 400
if plot_type == "Histogram":
output = {
"barmode": "overlay",
"bargap": 0.01,
"bargroupgap": 0.2,
"xaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)},
"yaxis": {"title": "Count"},
}
layout.update(
{
"barmode": "overlay",
"bargap": 0.01,
"bargroupgap": 0.2,
"xaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)},
"yaxis": {"title": "Count"},
}
)
elif plot_type == "Box plot":
output = {"yaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)}}
layout.update({"yaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)}})
else:
output = {
"margin": {"l": 40, "r": 40, "b": 30, "t": 10},
"yaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)},
"xaxis": {"title": "Realization"},
}
output["height"] = 400
output["colorway"] = colors
return output
layout.update(
{
"margin": {"l": 40, "r": 40, "b": 30, "t": 10},
"yaxis": {"title": VOLUME_TERMINOLOGY.get(response, response)},
"xaxis": {"title": "Realization"},
}
)

# output["colorway"] = colors
return layout


@CACHE.memoize(timeout=CACHE.TIMEOUT)
Expand Down
12 changes: 8 additions & 4 deletions webviz_subsurface/plugins/_inplace_volumes_onebyone.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(
self.tornadoplot = TornadoPlot(app, realizations, allow_click=True)
self.uid = uuid4()
self.selectors_id = {x: self.ids(x) for x in self.selectors}
self.plotly_theme = app.webviz_settings["theme"].plotly_theme
self.set_callbacks(app)

def ids(self, element):
Expand Down Expand Up @@ -477,8 +478,12 @@ def _render_vol_chart(plot_type, ensemble, response, source, *filters):
table = calculate_table_rows(data, response)

# Make Plotly figure
layout = {}
layout.update(self.plotly_theme["layout"])
layout.update({"margin": {"l": 100}})
if plot_type == "Per realization":
# One bar per realization
layout.update({"xaxis": {"title": "Realizations"}})
plot_data = data.groupby("REAL").sum().reset_index()
figure = wcc.Graph(
config={"displayModeBar": False},
Expand All @@ -492,11 +497,12 @@ def _render_vol_chart(plot_type, ensemble, response, source, *filters):
"type": "bar",
}
],
"layout": {"xaxis": {"title": "Realizations"}},
"layout": layout,
},
)
elif plot_type == "Box plot":
# One box per sensitivity name
layout.update({"title": "Distribution for each sensitivity"})
figure = wcc.Graph(
config={"displayModeBar": False},
id=self.ids("graph"),
Expand All @@ -511,11 +517,9 @@ def _render_vol_chart(plot_type, ensemble, response, source, *filters):
}
for sensname, dframe in data.groupby(["SENSNAME"])
],
"layout": {"title": "Distribution for each sensitivity"},
"layout": layout,
},
)
else:
print(plot_type)
tornado = json.dumps(
{
"ENSEMBLE": ensemble,
Expand Down
39 changes: 15 additions & 24 deletions webviz_subsurface/plugins/_parameter_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(self, app, ensembles, drop_constants: bool = True):
for ens in ensembles
}
self.drop_constants = drop_constants
self.plotly_theme = app.webviz_settings["theme"].plotly_theme

self.uid = uuid4()
self.set_callbacks(app)

Expand Down Expand Up @@ -191,7 +193,9 @@ def _update_matrix(ens, param1, param2):
and it is not possible to assign callbacks to individual
elements of a Plotly graph object
"""
fig = render_matrix(ens, self.drop_constants)
fig = render_matrix(
ens, theme=self.plotly_theme, drop_constants=self.drop_constants
)
# Finds index of the currently selected cell
x_index = list(fig["data"][0]["x"]).index(param1)
y_index = list(fig["data"][0]["y"]).index(param2)
Expand Down Expand Up @@ -223,7 +227,9 @@ def _update_matrix(ens, param1, param2):
],
)
def _update_scatter(ens1, param1, ens2, param2, color, density):
return render_scatter(ens1, param1, ens2, param2, color, density)
return render_scatter(
ens1, param1, ens2, param2, color, density, self.plotly_theme
)

@app.callback(
[
Expand Down Expand Up @@ -264,7 +270,7 @@ def get_parameters(ensemble_path) -> pd.DataFrame:


@CACHE.memoize(timeout=CACHE.TIMEOUT)
def render_scatter(ens1, x_col, ens2, y_col, color, density):
def render_scatter(ens1, x_col, ens2, y_col, color, density, theme):
if ens1 == ens2:
real_text = [f"Realization:{r}" for r in get_parameters(ens1)["REAL"]]
else:
Expand All @@ -285,24 +291,8 @@ def render_scatter(ens1, x_col, ens2, y_col, color, density):
"showlegend": False,
}
)
data.append(
{
"x": x,
"type": "histogram",
"yaxis": "y2",
"showlegend": False,
"marker": {"color": "rgb(31, 119, 180)"},
}
)
data.append(
{
"y": y,
"type": "histogram",
"xaxis": "x2",
"showlegend": False,
"marker": {"color": "rgb(31, 119, 180)"},
}
)
data.append({"x": x, "type": "histogram", "yaxis": "y2", "showlegend": False})
data.append({"y": y, "type": "histogram", "xaxis": "x2", "showlegend": False})
if density:
data.append(
{
Expand All @@ -324,7 +314,6 @@ def render_scatter(ens1, x_col, ens2, y_col, color, density):
],
"contours": {
"coloring": "fill",
# 'end': 80.05,
"showlines": True,
"size": 5,
"start": 5,
Expand All @@ -339,6 +328,7 @@ def render_scatter(ens1, x_col, ens2, y_col, color, density):
layout = {
"margin": {"t": 20, "b": 50, "l": 200, "r": 200},
"bargap": 0.05,
"colorway": theme["layout"]["colorway"],
"xaxis": {
"title": x_col,
"domain": [0, 0.85],
Expand Down Expand Up @@ -394,19 +384,20 @@ def get_corr_data(ensemble_path, drop_constants=True):


@CACHE.memoize(timeout=CACHE.TIMEOUT)
def render_matrix(ensemble_path, drop_constants=True):
def render_matrix(ensemble_path, theme, drop_constants=True):
corrdf = get_corr_data(ensemble_path, drop_constants)
# pylint: disable=no-member
corrdf = corrdf.mask(np.tril(np.ones(corrdf.shape)).astype(np.bool))

data = {
"type": "heatmap",
"x": corrdf.columns,
"y": corrdf.columns,
"z": list(corrdf.values),
"zmin": -1,
"zmax": 1,
"colorscale": theme["layout"]["colorscale"]["sequential"],
}

layout = {
"paper_bgcolor": "rgba(0,0,0,0)",
"plot_bgcolor": "rgba(0,0,0,0)",
Expand Down
Loading

0 comments on commit 148500f

Please sign in to comment.