diff --git a/cand/canvas.py b/cand/canvas.py index 30ec1e4..078c462 100644 --- a/cand/canvas.py +++ b/cand/canvas.py @@ -339,6 +339,8 @@ def convert_to_absolute_coord(self, point): return Point(point.x*self.size[0], point.y*self.size[1], "absolute") if point.coordinate == "-absolute": return Point(self.size[0]-point.x, self.size[1]-point.y, "absolute") + if point.coordinate in ["Msize", "fontsize"]: # Msize for backward compatibility + return self.convert_to_absolute_coord(Point(point.x*self.fontsize, point.y*self.fontsize, "point")) if point.coordinate in self.axes.keys(): # The call to autoscale_view fix the problem that automatic data # limits are updated lazily, and thus, gives an outdated transData @@ -588,9 +590,6 @@ def add_legend(self, pos_tl, els, fontsize=None, line_spacing=Height(2.2, "Msize assert len(els) >= 1 # Get the text height fprops = self._get_font() - t = matplotlib.textpath.TextPath((0,0), "M", size=fontsize, prop=fprops) - if self.is_valid_identifier("Msize"): - self.add_unit("Msize", Vector(self.fontsize, self.fontsize, "point")) # All params are in units of M width or height padding_top = Height(0, "Msize") # Space on top of figure padding_left = Width(0, "Msize") # Space on left of lines diff --git a/doc/_static/images/ddmdiagram.png b/doc/_static/images/ddmdiagram.png new file mode 100644 index 0000000..61ac4fb Binary files /dev/null and b/doc/_static/images/ddmdiagram.png differ diff --git a/doc/_static/images/fokkerplanck.png b/doc/_static/images/fokkerplanck.png new file mode 100644 index 0000000..bc99838 Binary files /dev/null and b/doc/_static/images/fokkerplanck.png differ diff --git a/doc/downloads/ddmdiagram.py b/doc/downloads/ddmdiagram.py new file mode 100644 index 0000000..020bae0 --- /dev/null +++ b/doc/downloads/ddmdiagram.py @@ -0,0 +1,486 @@ +# Begin preamble +from ddm import Model, DriftConstant +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +import scipy.stats +import random + +sns.set_palette("Set1") +color_bound = sns.color_palette()[0] +color_x0 = sns.color_palette()[1] +color_drift = sns.color_palette()[2] +color_ev = sns.color_palette()[3] +color_nd = sns.color_palette()[4] +# End preamble + +# Begin plot function +def set_up_plot(ax): + sns.despine(bottom=True, ax=ax) + ax.axhline(0, c="gray", linestyle="--") + ax.set_xlim(0, 0.55) + ax.set_ylim(-1, 1) + ax.set_xticks([]) + ax.set_yticks([]) + ax.spines["left"].set_linewidth(2) +# End plot function + + +# Begin canvas initialization +from cand import Canvas, Width, Height, Point, Vector + +c = Canvas(4, 9) +c.set_font("Nimbus sans", size=8) +# End canvas initialization + +##### DDM SCHEMATIC ##### + +# Begin default unit +c.add_unit("absolute2", Vector(1, 1, "absolute"), Point(0, 1, "absolute")) +c.set_default_unit("absolute2") +# End default unit + +# Begin other axes +c.add_axis("ddm", Point(0.3, 6.5), Point(3.7, 7.6)) +c.add_axis("evidence", Point(0.3, 6.2), Point(3.7, 6.4)) +# End other axes + +# Begin model simulation +model = Model(drift=DriftConstant(drift=0.8), dt=0.01, dx=0.01) +sim_trial = model.simulate_trial(seed=8) +sim_trial[20:] = -sim_trial[20:] * 1.6 +sim_trial = sim_trial[0:38] +# End model simulation + +# Begin ddm plot +ax = c.ax("ddm") +ax.plot(model.t_domain()[0 : len(sim_trial)], sim_trial, c="k") +ax.axhline(1, c=color_bound, clip_on=False) +ax.axhline(-1, c=color_bound, clip_on=False) +ax.set_ylabel("Decision variable") +set_up_plot(ax) +# End ddm plot + +# Begin ddm annotate +non_decision_time = 0.15 +ndstartpoint = Point(model.t_domain()[37], -0.1, "ddm") +c.add_arrow( + ndstartpoint, + ndstartpoint + Vector(non_decision_time, 0, "ddm"), + arrowstyle="|-|,widthA=5,widthB=5", + color=color_nd, +) +c.add_text( + "Non-\ndecision\ntime", + ndstartpoint + Vector(non_decision_time / 2, -0.15, "ddm"), + verticalalignment="top", + color=color_nd, +) +ndgap = 0.003 +# End ddm annotate + +# Begin ddm bounds label +c.add_text("Bounds", Point(0.8, 0.93, "axis_ddm"), color=color_bound) +# End ddm bounds label + +# Begin drift rate arrow +arrow_start = Point(0.17, 0.2, "ddm") +arrow_len = Vector(0.3, 0.3, "absolute") +c.add_arrow( + arrow_start, + arrow_start + arrow_len, + arrowstyle="-|>,head_width=4,head_length=8", + lw=3, + color=color_drift, + joinstyle="miter", +) +c.add_text( + "Drift", + arrow_start + arrow_len.height(), + verticalalignment="top", + color=color_drift, +) +# End drift rate arrow + +# Begin x0 label +ax.scatter([0], [0], color=color_x0, s=70, clip_on=False) +c.add_text( + "Starting\npoint", + Point(0.01, 0.07, "ddm"), + horizontalalignment="left", + verticalalignment="bottom", + color=color_x0, +) +# End x0 label + +# Begin label evidence over time +ax = c.ax("evidence") +ax.plot([0, 1], [1, 1], c=color_ev) + + +def make_evgrid_axis(ax): + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_xlim([0, 1]) + ax.spines["left"].set_linewidth(2) + ax.spines["bottom"].set_linewidth(2) + + +make_evgrid_axis(ax) +ax.set_xlabel("Time") +ax.set_ylabel("Evidence") +# End label evidence over time + +#################### Full DDM #################### + +# Begin set up full ddm +c.add_grid( + ["fullx0", "fulldrift", "fullnd"], + 1, + Point(0.3, 4.3), + Point(3.7, 5.3), + size=Vector(0.7, 1), +) +c.add_axis("evidence_ddm", Point(0.3, 4.0), Point(3.7, 4.2)) +# End set up full ddm + +# Begin make starting pos +def make_starting_position_plot(axname, dist_func): + zoom_area = (-0.9, 0.9) + ax = c.ax(axname) + dist_x = np.linspace(zoom_area[0], zoom_area[1], 500) + ax.fill_betweenx( + dist_x, -(dist_func(dist_x)) / 70, 0, clip_on=False, color=color_x0 + ) + set_up_plot(ax) + ax.set_ylim(*zoom_area) + ax.set_xlim(0, 0.3) + offset = 0.25 + ax.plot(model.t_domain()[0 : len(sim_trial)] * 2, offset + sim_trial, c="k") + + +make_starting_position_plot("fullx0", scipy.stats.uniform(-0.05, 0.4).pdf) +c.add_text( + "Uniformly-distributed\nstarting point", + Point(0, 0, "fullx0") + Vector(0.1, -0.1, "absolute"), + horizontalalignment="left", + verticalalignment="top", +) +# End make starting pos + +# Begin make nd time +def make_ndtime(axname, dist_func, height=0.05): + ax = c.ax(axname) + set_up_plot(ax) + ax.set_xlim(0, 0.3) + ax.set_ylim(-0.3, 0.3) + c.add_axis( + axname + "dist", + Point(0.05 + non_decision_time / 2, -0.15, axname), + Point(0.05 + non_decision_time * 1.5, -0.15 + height, axname), + ) + c.ax(axname + "dist").axis("off") + nddist_x = np.linspace(-1, 1, 500) + c.ax(axname + "dist").fill_between( + nddist_x, dist_func(nddist_x) / 50, 0, color=color_nd + ) + c.ax(axname + "dist").set_xlim(-1, 1) + sns.despine(ax=ax, left=True, bottom=True) + c.add_arrow( + Point(0.05 + ndgap, -0.05, axname), + Point(0.05 + non_decision_time - ndgap, -0.05, axname), + arrowstyle="|-|,widthA=5,widthB=0", + color=color_nd, + ) + + +make_ndtime("fullnd", scipy.stats.uniform(-0.8, 1.6).pdf) +c.add_text( + "Uniformly-distributed\nnon-decision time", + Point(1, 0.5, "axis_fullnd") + Vector(0, 0.1, "absolute"), + horizontalalignment="right", + verticalalignment="bottom", +) +# End make nd time + +# Begin full drift rate +ax = c.ax("fulldrift") +set_up_plot(ax) +arrgapx = 0.05 +arrgapy = 0.05 +anchor = 0.15 +arrlen = Vector(0.3, 0.3, "absolute") +ax.set_ylim(-0.5, 0.5) +ax.set_xlim(anchor - 0.01, anchor + 0.1) +arrowtilt = Vector(0, 0) +arrow_lr = Point(anchor + arrgapx, arrgapy, "fulldrift") +c.add_arrow( + arrow_lr, + arrow_lr + arrlen + arrowtilt, + arrowstyle="-|>,head_width=4,head_length=8", + lw=3, + color=color_drift, + joinstyle="miter", +) +sns.despine(ax=ax, left=True, bottom=True) +drift_x = np.linspace(-1, 1, 500) +drift_mat = np.asarray([drift_x, scipy.stats.norm(0, 0.3).pdf(drift_x) / 3]).T +cos, sin = np.cos(3.141592 / 4), np.sin(3.141592 / 4) +rotated_drift_mat = drift_mat @ np.array([[cos, -sin], [sin, cos]]) +c.add_axis("miniarrowdist", arrow_lr - arrlen * 0.7, arrow_lr + arrlen * 0.7) +c.ax("miniarrowdist").axis("off") +c.ax("miniarrowdist").fill_between( + rotated_drift_mat[:, 0], + rotated_drift_mat[:, 1], + -rotated_drift_mat[:, 0], + color=color_drift, +) +c.add_text( + "Gaussian-\ndistributed\ndrift rate", + Point(anchor, 0, "fulldrift") + Vector(0, -0.13, "absolute"), + horizontalalignment="left", + verticalalignment="top", +) +# End full drift rate + +# Begin evidence ddm +ax = c.ax("evidence_ddm") +ax.plot([0, 1], [1, 1], c=color_ev) +make_evgrid_axis(ax) +ax.set_xlabel("Time") +ax.set_ylabel("Evidence") +# End evidence ddm + +#################### GDDM #################### + +# Begin gddm grid +c.add_grid( + ["gddmx0", "gddmdrift", "gddmnd", "gddmbounds", "gddmleaky", "gddmunstable"], + 2, + Point(0.3, 0.7), + Point(3.7, 3.0), + size=Vector(0.7, 1), +) +# End gddm grid + +# Begin gddm starting pos +fancy_distribution = lambda dist_x: 0.7 * ( + scipy.stats.norm(0.1, 0.05).pdf(dist_x) + + scipy.stats.norm(0.35, 0.12).pdf(dist_x) + + 0.5 * scipy.stats.norm(0.6, 0.05).pdf(dist_x) +) +make_starting_position_plot("gddmx0", fancy_distribution) +c.add_text( + "Any distribution for\nstarting point", + Point(0, 0, "gddmx0") + Vector(0.1, -0.1, "absolute"), + horizontalalignment="left", + verticalalignment="top", +) +# End gddm starting pos + +# Begin gddm nd time +fancy_distribution2 = lambda dist_x: scipy.stats.norm(0.5, 0.19).pdf( + dist_x +) + scipy.stats.norm(-0.2, 0.3).pdf(dist_x) +make_ndtime("gddmnd", fancy_distribution2, height=0.12) +c.add_text( + "Any distribution for\nnon-decision time", + Point(1, 0.5, "axis_gddmnd") + Vector(0, 0.1, "absolute"), + horizontalalignment="right", + verticalalignment="bottom", +) +# End gddm nd time + +# Begin gddm drift +ax = c.ax("gddmdrift") +set_up_plot(ax) + +x = np.linspace(-2, 2, 1000) +y = np.sin(2 * np.pi * x * 0.5 + 0.2) * 1 / ((3 + np.abs(x ** 8))) +x += 0.7 +y += 0.4 +rotated_drift_mat = np.asarray([x, y]).T @ np.array([[cos, sin], [-sin, cos]]) * 1 +ax.plot(rotated_drift_mat[:, 0], rotated_drift_mat[:, 1], c=color_drift, lw=3) +arrowstart = Point(rotated_drift_mat[-1, 0], rotated_drift_mat[-1, 1], "gddmdrift") +sns.despine(ax=ax, left=True, bottom=True) +ax.set_xlim(-2, 3) +ax.set_ylim(-3.5, 3.5) +c.add_arrow( + arrowstart, + arrowstart + Vector(0.12, 0.12, "absolute"), + arrowstyle="-|>,head_width=4,head_length=8", + lw=3, + color=color_drift, + joinstyle="miter", +) +c.add_text( + "Arbitrary\ntime-varying\ndrift rate", + Point(non_decision_time, 0, "gddmdrift") + Vector(0, -0.05, "absolute"), + horizontalalignment="left", + verticalalignment="top", +) +# End gddm drift + +# Begin gddm bounds +ax = c.ax("gddmbounds") +set_up_plot(ax) +bounds_x = np.linspace(0, 1.5, 500) +l = 0.9 +k = 3 +a = 1 +aprime = 0.1 +bounds_y = a - (1 - np.exp(-((bounds_x / l) ** k))) * (a - aprime) +ax.plot(bounds_x, bounds_y, c=color_bound) +ax.plot(bounds_x, -bounds_y, c=color_bound) +ax.set_ylim(-1.1, 1.1) +ax.set_xlim(0, 1.5) +c.add_text( + "Time-varying\nbounds", + Point(1, 1, "axis_gddmbounds") + Vector(-0.35, 0.1, "absolute"), + horizontalalignment="left", + verticalalignment="top", +) +# End gddm bounds + +# Begin leaky +ax = c.ax("gddmleaky") +set_up_plot(ax) +xdom = np.linspace(0, 0.3, 50) +xpos = 0.1 + xdom +ypos = 0.8 * np.exp(-xdom / 0.2) +size = Vector(0.1, 0.1, "axis_gddmleaky") +for x in [0.1, 0.3, 0.5, 0.7, 0.9]: + for y in [0.1, 0.35]: + c.add_arrow( + Point(x, y, "axis_gddmleaky") - size / 2, + Point(x, y, "axis_gddmleaky") + size / 2, + linewidth=1, + arrowstyle="-|>,head_width=1,head_length=1", + color="k", + ) + for y in [0.65, 0.9]: + c.add_arrow( + Point(x, y, "axis_gddmleaky") - size.flipy() / 2, + Point(x, y, "axis_gddmleaky") + size.flipy() / 2, + linewidth=1, + arrowstyle="-|>,head_width=1,head_length=1", + color="k", + ) +c.add_text( + "Leaky integration", + Point(0.5, 1.05, "axis_gddmleaky"), + horizontalalignment="center", + verticalalignment="bottom", +) +# End leaky + +# Begin unstable +ax = c.ax("gddmunstable") +set_up_plot(ax) +xdom = np.linspace(0, 0.3, 50) +xpos = 0.1 + xdom +ypos = 0.8 * np.exp(-xdom / 0.2) +size = Vector(0.1, 0.1, "axis_gddmunstable") +for x in [0.1, 0.3, 0.5, 0.7, 0.9]: + for y in [0.1, 0.35]: + c.add_arrow( + Point(x, y, "axis_gddmunstable") - size.flipy() / 2, + Point(x, y, "axis_gddmunstable") + size.flipy() / 2, + linewidth=1, + arrowstyle="-|>,head_width=1,head_length=1", + color="k", + ) + for y in [0.65, 0.9]: + c.add_arrow( + Point(x, y, "axis_gddmunstable") - size / 2, + Point(x, y, "axis_gddmunstable") + size / 2, + linewidth=1, + arrowstyle="-|>,head_width=1,head_length=1", + color="k", + ) +c.add_text( + "Unstable integration", + Point(0.5, 1.05, "axis_gddmunstable"), + horizontalalignment="center", + verticalalignment="bottom", +) +# End unstable + +# Begin evidence grid +pos_ll = Point(0.3, 0.3, "absolute") +pos_ur = Point(3.7, 1.5, "absolute") +evgrids = ["gddmev1", "gddmev2", "gddmev3", "gddmev4"] +c.add_grid( + evgrids, + 4, + pos_ll, + pos_ur, + size_x=(pos_ur - pos_ll).width(), + size_y=Height(0.2, "absolute"), +) +for eg in evgrids: + make_evgrid_axis(c.ax(eg)) + sns.despine(ax=c.ax(eg)) +# End evidence grid + +# Begin gddm evidence +c.add_text("Evidence", (pos_ll | (pos_ur << pos_ll)) - Width(0.1), rotation=90) +c.ax(evgrids[-1]).set_xlabel("Time") +evs_x = np.linspace(0, 1, 1000) + +# Crazy evidence +random.seed(4) +spikes = 40 * [1] + 992 * [0] +random.shuffle(spikes) +c.ax(evgrids[0]).plot( + evs_x, + scipy.stats.gaussian_kde( + [i / 1000 for i, s in enumerate(spikes) if s], bw_method=0.1 + )(evs_x), + c=color_ev, +) + +# Poisson evidence +random.seed(2) +spikes = 8 * [1] + 992 * [0] +random.shuffle(spikes) +c.ax(evgrids[1]).plot(evs_x, spikes, c=color_ev) + +# Pulse evidence +y = [1 if x > 0.2 and x < 0.35 else 0.4 for x in evs_x] +c.ax(evgrids[2]).plot(evs_x, y, c=color_ev) +c.ax(evgrids[2]).set_ylim(0, 1.1) + +c.add_text( + "Any form of\nevidence", + Point(1, 1, "axis_" + evgrids[0]) + Vector(0.08, 0.08, "absolute"), + horizontalalignment="right", + verticalalignment="top", +) + +# Changing evidence +N = 40 +np.random.seed(3) +xs = [0] + list(np.repeat(range(0, N), 2)) + [N] +step_heights = np.random.beta(0.3, 0.3, N + 1) +ys = np.repeat(step_heights, 2) +c.ax(evgrids[3]).plot(np.asarray(xs) / N, ys, c=color_ev) +# End gddm evidence + +# Begin section labels +c.add_text("DDM", Point(0.15, 7.8), weight="bold", size=10, horizontalalignment="left") +c.add_text( + "Full DDM", Point(0.15, 5.5), weight="bold", size=10, horizontalalignment="left" +) +c.add_text( + "GDDM (examples)", + Point(0.15, 3.2), + weight="bold", + size=10, + horizontalalignment="left", +) +# End section labels + +# Begin save +c.save("ddmdiagram.png") +# End save diff --git a/doc/downloads/fokkerplanck.py b/doc/downloads/fokkerplanck.py new file mode 100644 index 0000000..62e02a0 --- /dev/null +++ b/doc/downloads/fokkerplanck.py @@ -0,0 +1,198 @@ +# Begin imports +import ddm # Requires PyDDM to be installed +from ddm import * +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +from cand import Canvas, Point, Width, Height, Vector +# End imports + +# Begin set up canvas and axes +c = Canvas(4.85, 1.5) +c.set_font("Nimbus Sans", size=8) +c.set_default_unit("absolute") + +c.add_axis("trial", Point(0.3, 0.3), Point(1.3, 1)) +c.add_axis("trials", Point(1.7, 0.3), Point(2.7, 1)) +c.add_axis("fp", Point(3.3, 0.3), Point(4.3, 1)) +c.add_arrow(Point(1.05, 0.5, "trial"), Point(-0.05, 0.5, "trials")) +c.add_text( + "DV", + Point(0, 0.5, "axis_trial") - Width(0.05, "absolute"), + horizontalalignment="right", + verticalalignment="center", + rotation="vertical", +) +c.add_text( + "Time", + Point(0.5, 0, "axis_trial") - Height(0.15, "absolute"), + horizontalalignment="center", + verticalalignment="center", +) +c.add_text( + "Time", + Point(0.5, 0, "axis_trials") - Height(0.15, "absolute"), + horizontalalignment="center", + verticalalignment="center", +) +c.add_figure_labels([("a", "trial"), ("b", "fp")]) +# End set up canvas and axes + +# Begin top bottom axes +def add_hists(axname, shift=True): + c.add_axis( + axname + "_top", + Point(0, (1.04 if shift else 1), "axis_" + axname), + Point(1, 1.3, "axis_" + axname), + ) + c.add_axis( + axname + "_bot", + Point(0, -0.3, "axis_" + axname), + Point(1, (-0.04 if shift else 0), "axis_" + axname), + ) + c.ax(axname).axis("off") + c.ax(axname + "_top").axis("off") + c.ax(axname + "_bot").axis("off") +# End top bottom axes + +# Begin finalize hist +def finalize_hists(axname): + c.ax(axname + "_bot").set_ylim(c.ax(axname + "_top").get_ylim()) + c.ax(axname + "_bot").invert_yaxis() +# End finalize hist + + +# Begin model +T_dur = 2 +model = Model(drift=ddm.DriftConstant(drift=0.8), dt=0.01, dx=0.01) +# End model + +# Begin create trajectories +def draw_trials(model, axname, N=1, seedstart=7, alpha=1): + # Create three-rowed figure + add_hists(axname) + + # Draw DDM axes + ax = c.ax(axname) + ax.plot([0, T_dur], [0, 0], c="gray", clip_on=False, lw=0.5) + ax.plot([0, T_dur], [1.04, 1.04], c="red", clip_on=False, lw=2) + ax.plot([0, T_dur], [-1.04, -1.04], c="red", clip_on=False, lw=2) + # sns.despine(bottom=True, ax=ax_main) + ax.get_xaxis().set_ticks([]) + ax.get_yaxis().set_ticks([]) + ax.axis([0, T_dur, -1, 1]) + + # Draw paths and the corresponding histogram + corr_times = [] + err_times = [] + for seed in range(seedstart, N + seedstart): + Y = model.simulate_trial(seed=seed) + X = model.t_domain() + X = X[0 : len(Y)] + if Y[-1] > 1: + corr_times.append(X[len(Y) - 1]) + elif Y[-1] < -1: + err_times.append(X[len(Y) - 1]) + ax.plot(X, Y, linewidth=0.3, c="k", alpha=alpha) + + c.ax(axname + "_top").hist(corr_times, bins=41, range=(0, model.T_dur)) + c.ax(axname + "_bot").hist(err_times, bins=41, range=(0, model.T_dur)) + finalize_hists(axname) + c.ax(axname + "_top").set_xlim(-0.025, T_dur + 0.025) + c.ax(axname + "_bot").set_xlim(-0.025, T_dur + 0.025) + + +# One trial +draw_trials(model, "trial", N=1, seedstart=8) + +# Several trials +# Note that this step is slow +draw_trials(model, "trials", N=400, seedstart=0, alpha=0.25) +# End create trajectories + + +# Begin fp grid +add_hists("fp", shift=False) +s = model.solve_numerical_implicit(return_evolution=True) +grid = np.sqrt(s.pdf_evolution()) +# End fp grid + +# Begin top/bottom pdf +s = model.solve_numerical_implicit() +top = s.pdf_corr() +bot = s.pdf_err() +# Show the relevant data on those axes +c.ax("fp").imshow( + np.log10(grid ** 2 + 1e-5), + aspect="auto", + interpolation="bicubic", + cmap="inferno", + vmin=-2, + vmax=0, +) +c.ax("fp").invert_yaxis() +c.ax("fp_top").plot(np.linspace(0, len(top) - 1, len(top)), top, clip_on=False) +c.ax("fp_bot").plot(np.linspace(0, len(top) - 1, len(bot)), bot, clip_on=False) +c.add_text( + "DV", + Point(0, 0.5, "axis_fp") - Width(0.05, "absolute"), + horizontalalignment="right", + verticalalignment="center", + rotation="vertical", +) +c.add_text( + "Time", + Point(0.5, 0, "axis_fp") - Height(0.15, "absolute"), + horizontalalignment="center", + verticalalignment="center", +) +axsize = plt.axis() +c.ax("fp").plot( + [-0.35, len(grid[0, :]) - 0.65], [1.0, 1.0], c="red", clip_on=False, lw=2 +) +c.ax("fp").plot( + [-0.35, len(grid[0, :]) - 0.65], + [len(grid[:, 0]) - 1, len(grid[:, 0]) - 1], + c="red", + clip_on=False, + lw=2, +) +plt.axis(axsize) +# Set axes to be the right size +finalize_hists("fp") +# End top/bottom pdf + +# Begin colorbar +cb_norm = plt.matplotlib.colors.LogNorm(vmin=1e-2, vmax=1) +cb = c.add_colorbar( + "cbar", + Point(1.1, 0, "axis_fp"), + Point(1.15, 1, "axis_fp"), + cmap="inferno", + bounds=cb_norm, +) +cb.set_ticks([0.01, 0.1, 1]) +c.ax("cbar").tick_params(axis="y", which="minor", right="off") +c.add_text("Probability", Point(3, 1.2, "axis_cbar")) +# End colorbar + +# Begin finalize +c.add_text( + "Trial-wise trajectory simulation", + (Point(0, 1, "axis_trials") | Point(1, 1, "axis_trial")) + + Height(0.3), + horizontalalignment="center", + verticalalignment="center", + weight="bold", +) + +c.add_text( + "Fokker-Planck", + Point(0.5, 1, "axis_fp") + Height(0.3), + horizontalalignment="center", + verticalalignment="center", + weight="bold", +) + +c.save("fokkerplanck.png") +# End finalize diff --git a/doc/gallery/ddmdiagram.rst b/doc/gallery/ddmdiagram.rst new file mode 100644 index 0000000..f56428c --- /dev/null +++ b/doc/gallery/ddmdiagram.rst @@ -0,0 +1,257 @@ +Generalized Drift Diffusion Model diagram +========================================= + +Summary +------- + +Here, we will show how to build each individual component found in the diagram +describing the generalized drift diffusion model (GDDM), as seen in `Figure 1 +`_ from `Shinn et +al. 2020 `_. This is a good +demonstration because it does not require data files to produce. + +Setting up the figure +~~~~~~~~~~~~~~~~~~~~~ + +First, we import the plotting libraries and define some basic properties: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin preamble + :end-before: # End preamble + +Define a function we will use later to make the plotting look nice: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin plot function + :end-before: # End plot function + + +Now, we initialize the Canvas object within CanD: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin canvas initialization + :end-before: # End canvas initialization + +DDM schematic +~~~~~~~~~~~~~ + +We create a new unit, which we call "absolute2". We do this so that we can use +a separate coordinate system for each part of the plot. This makes it easy to +adjust the position of the different parts separately. Then, we set it as the +default unit so that we don't have to worry about adding "absolute2" to the end +of all of our Points: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin default unit + :end-before: # End default unit + +We add two axes, one for the actual DDM, and one for the plot showing the change +in evidence over time: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin other axes + :end-before: # End other axes + +First, we simulate the model to create the diagram. Note that we make some +cosmetic changes to the DDM trace for the purpose of clarity within the +diagram: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin model simulation + :end-before: # End model simulation + +Next, we plot the trace on the axis, and draw the upper and lower bounds on the +same axis: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin ddm plot + :end-before: # End ddm plot + +Now, we annotate the DDM diagram. First, we draw an an arrow to denote the +non-decision time. We make the arrow a distance indicator, and change the shape +to make it attractive. We position it using the non-decision time startpoint, +determined from the coordinates within the axis: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin ddm annotate + :end-before: # End ddm annotate + +Label the bounds: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin ddm bounds label + :end-before: # End ddm bounds label + +Draw an arrow indicating the drift rate, and label it: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin drift rate arrow + :end-before: # End drift rate arrow + +Draw and label the starting point: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin x0 label + :end-before: # End x0 label + +Draw and label the evidence over time: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin label evidence over time + :end-before: # End label evidence over time + +Full DDM schematic +~~~~~~~~~~~~~~~~~~ + +First, we set up a grid of axes for the starting point, variable drift rate, and +non-decision time features of the Full DDM, as well as the evidence over time +axis: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin set up full ddm + :end-before: # End set up full ddm + +The following will allow us to draw any distribution sideways on an axis. We +write this as a function because, while it is simply a uniform distribution +here, we will want to draw a more complicated distribution later for the GDDM. +This functions by transforming the points to data space, setting the clip off, +and then filling in the region between the curve and the axis. This is mostly +performed in pure matplotlib, and does not require features from CanD: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin make starting pos + :end-before: # End make starting pos + +We also draw an arbitrary distribution for the non-decision time. By contrast, +this uses some features of CanD. It operates by creating a new axis on top of +the existing axis, turning off the spines, and drawing the distribution on the +new axis. Then, it uses CanD to draw an "arrow" (that looks like "|---"), +similar to the one drawn on the DDM plot: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin make nd time + :end-before: # End make nd time + +Lastly, we draw the drift rate. We seek to draw an arrow with a Gaussian +distribution at the bottom. We accomplish this by defining a Gaussian +distribution, rotating it by 45 degrees, placing it on a new axis. We set the +position of the new axis based on the position of the arrow: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin full drift rate + :end-before: # End full drift rate + +And, as before, we plot the constant evidence signal: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin evidence ddm + :end-before: # End evidence ddm + +GDDM +~~~~ + +First, we work on the six main plots showing GDDM features. We will look at the +evidence streams later. As in the Full DDM case, create a grid of axes: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin gddm grid + :end-before: # End gddm grid + +We can reuse the code from the Full DDM to easily plot the starting position +distribution: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin gddm starting pos + :end-before: # End gddm starting pos + +and the non-decision time distribution: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin gddm nd time + :end-before: # End gddm nd time + +To form the curved arrow in the drift rate plot, we create a sine-like wave +showing what we want, and then rotate it by 45 degrees, the same way we did +above. We use the rotated coordinates as the start of the arrow, and then make +the arrow short. We need to set the xlim and ylim before drawing the arrow but +after plotting the function, or else matplotlib may readjust the axis limits +after the arrow has been drawn. + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin gddm drift + :end-before: # End gddm drift + +We draw bounds as red lines which converge to the center. This is mostly pure +matplotlib. + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin gddm bounds + :end-before: # End gddm bounds + +Draw a grid of arrows. We want the arrows to face up when below the midpoint +and down when above it. So we loop through x coordinates and y coordinates in a +grid. + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin leaky + :end-before: # End leaky + +We do the same thing, except with arrows facing the other way. + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin unstable + :end-before: # End unstable + +Create a grid of axes to use for the different evidence streams: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin evidence grid + :end-before: # End evidence grid + +Now plot the streams on these axes: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin gddm evidence + :end-before: # End gddm evidence + +Add section labels: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin section labels + :end-before: # End section labels + +Finally, save the figure: + +.. literalinclude:: ../downloads/ddmdiagram.py + :language: python + :start-after: # Begin save + :end-before: # End save + +.. image:: ../_static/images/ddmdiagram.png diff --git a/doc/gallery/fokkerplanck.rst b/doc/gallery/fokkerplanck.rst new file mode 100644 index 0000000..aa6ab35 --- /dev/null +++ b/doc/gallery/fokkerplanck.rst @@ -0,0 +1,103 @@ +Fokker-Planck diagram +===================== + +Summary +------- + +Here, we will show how to build the schematic for Fokker-Planck, as seen in `Figure 3 +`_ from `Shinn et +al. 2020 `_. + +Setting up the figure +~~~~~~~~~~~~~~~~~~~~~ + +First, we import the plotting libraries and define some basic properties: + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin imports + :end-before: # End imports + +Next, we set up the canvas and define the three main axes we will use, the arrow +in between, and some basic axis labels. Note that this does not yet define the +histograms/distributions shown on the top and bottom of these axes. + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin set up canvas and axes + :end-before: # End set up canvas and axes + +Now, we create a function which will add histograms or plots of the pdf on the +top and bottom of these axes. We set it up as a function which accepts the name +of the axis. We use the "shift" argument to fine tune the positioning, since +matplotlib misaligns the histograms for some reason. + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin finalize hist + :end-before: # End finalize hist + +This function also operates on an axis. It runs formatting functions which must +be applied after everything is already plotted, such as flipping the bottom axis +on the lower histogram. + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin top bottom axes + :end-before: # End top bottom axes + +Define the model: + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin model + :end-before: # End model + +Plot trajectories on the axis. We abstract this into a function so that we can +run it twice, once for the plot with only one trajectory, and once for the plot +with multiple trajectories. + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin create trajectories + :end-before: # End create trajectories + +Now we build the heatmap (grid) showing the evolution of the drift diffusion +model under Fokker-Planck based methods. We build it into a function and call +it once for the correct axis. + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin fp grid + :end-before: # End fp grid + +Now, we create the distributions on the top and bottom of the heatmap, along +with the appropriate labels. + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin top/bottom pdf + :end-before: # End top/bottom pdf + +Additionally, we create a colorbar. We can't use the "vmin" and "vmax" +arguments of the :meth:`.Canvas.add_colorbar` method because we want to create a +colorbar with log scaling. Thus, we have to create a `matplotlib LogNorm +`_ +object. + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin colorbar + :end-before: # End colorbar + +Finally, we add labels and save the figure. + +Note that we determine the center position between the two leftmost plots using +the "|" operator, which evaluates to be the center between any two Points. + +.. literalinclude:: ../downloads/fokkerplanck.py + :language: python + :start-after: # Begin colorbar + :end-before: # End colorbar + +.. image:: ../_static/images/fokkerplanck.png diff --git a/doc/gallery/index.rst b/doc/gallery/index.rst new file mode 100644 index 0000000..03645a9 --- /dev/null +++ b/doc/gallery/index.rst @@ -0,0 +1,22 @@ +Gallery +======= + +The following are examples showing the use of CanD in real publications. They +will include quite a bit of non-CanD code as well, to illustrate how CanD can be +used in real life to create specific effects. + + +.. image:: ../_static/images/fokkerplanck.png + :target: fokkerplanck.html + +.. image:: ../_static/images/ddmdiagram.png + :target: ddmdiagram.html + +:ref:`modindex` + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + ddmdiagram + fokkerplanck diff --git a/doc/index.rst b/doc/index.rst index a284e96..b15dea4 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -16,6 +16,7 @@ problems to the `bug tracker `_. installing tutorial + gallery/index apidoc/index faqs contact diff --git a/doc/tutorial.rst b/doc/tutorial.rst index adf50f1..f0159f3 100644 --- a/doc/tutorial.rst +++ b/doc/tutorial.rst @@ -82,6 +82,12 @@ which may be used to position elements on that canvas. These include: direction. - "-figure": Indentical to "figure", except with the origin in the upper right corner. +- "fontsize": The default font size, e.g. if you have 8pt font, 2 units is 16pt. + See below for how to change the font size. Note that this is computed by + converting the size in points to the size on the figure. Thus, you cannot + rely on this being exactly the same distance as a given font, because + different fonts use different design choices. Nevertheless, it should be + close. - "default" (or none): The default unit. This is set to "absolute" when the canvas is created, but this can be changed. For example, ``Point(1, 2, "default")`` and ``Point(1, 2)`` are equivalent. @@ -133,7 +139,7 @@ the canvas is 0, the upper right corner is (1,1), and the lower left corner is Vector/Point arithmetic ---------------------------- +----------------------- It is possible to perform arithmetic on vectors, similar to the way we perform vector operations in linear algebra. Two vectors can be added and subtracted, @@ -568,7 +574,46 @@ possible to overlay other plot elements on top of images. Plot elements ------------- -add_legend, add_colorbar, add_figure_labels +CanD implements its own helper functions for several plot features. Some of +these are reimplemented from matplotlib features. It is still possible to use +the original matplotlib versions, but in many cases, the versions implemented by +CanD will be simpler. + +To add a legend, use the :meth:`.Canvas.add_legend` function. The first +argument ``pos_tl`` is the position of the top left corner, and the second +argument ``els`` is a list with a specific format to describe the content of the +legend. Each element of the list should be a tuple, where the first element is +the title, and the second element is a dictionary to describe the style. The +elements of this dictionary should correspond to those passed to the +:meth:`.Canvas.add_line` or :meth:`.Canvas.add_marker` functions. To use a +marker instead of a line, set linestyle to the string "None". You can +optionally pass additional arguments to control the spacing of the different +aspects of the legend. `line_spacing` determines spacing between each line of +descriptive text in the legend. `sym_width` is the width of the symbols (lines +and markers). `padding_sep` is the separation between the symbols and the +descriptive text. + +To add a colorbar, use :meth:`.Canvas.add_colorbar`. The first argument is the +name of the colorbar. This should be unique, and should not coincide with the +name of an axis, because this will be usable as a unit. The following two +arguments are the bottom left and upper right corners of the colorbar. The next +argument is a tuple containing the minimum and maximum value of the colorbar. +This colorbar function does not automatically map to a matplotlib axis, so the +axis limits (vmin and vmax) will have to be manually specified in both cases. +All remaining optional arguments are identical to those of matplotlib's +`ColorbarBase +`_, +notably, ``cmap``, which takes the name of a colormap to use for the colorbar. +Orientation is determined automatically. + +Additionally, labels can be added in a consistent manner with +:meth:`.Canvas.add_figure_labels`. The first argument is a list of tuples +describing the labels to add. The first element of each tuple is the text to +use for the label, such as "a", "b", etc. The second element of each tuple is +the name of the axis to which to add the label. The third element of the tuple +is optional, and specifies an offset in the position. Following the argument, +:meth:`.Canvas.add_figure_labels` function also takes an optional second +argument specifying the font size of the labels. Grids of axes -------------