Skip to content

Commit

Permalink
Figure.legend: Support passing a StringIO object as the legend specif…
Browse files Browse the repository at this point in the history
…ication (#3438)
  • Loading branch information
seisman authored Sep 19, 2024
1 parent 5e5a0c6 commit 5196ae6
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
9 changes: 6 additions & 3 deletions pygmt/src/legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
legend - Plot a legend.
"""

import io
import pathlib

from pygmt.clib import Session
Expand Down Expand Up @@ -30,7 +31,7 @@
@kwargs_to_strings(R="sequence", c="sequence_comma", p="sequence")
def legend(
self,
spec: str | pathlib.PurePath | None = None,
spec: str | pathlib.PurePath | io.StringIO | None = None,
position="JTR+jTR+o0.2c",
box="+gwhite+p1p",
**kwargs,
Expand All @@ -57,6 +58,7 @@ def legend(
file
- A string or a :class:`pathlib.PurePath` object pointing to the legend
specification file
- A :class:`io.StringIO` object containing the legend specification.
See :gmt-docs:`legend.html` for the definition of the legend specification.
{projection}
Expand Down Expand Up @@ -89,10 +91,11 @@ def legend(
kwargs["F"] = box

kind = data_kind(spec)
if kind not in {"vectors", "file"}: # kind="vectors" means spec is None
if kind not in {"vectors", "file", "stringio"}: # kind="vectors" means spec is None
raise GMTInvalidInput(f"Unrecognized data type: {type(spec)}")
if kind == "file" and is_nonstr_iter(spec):
raise GMTInvalidInput("Only one legend specification file is allowed.")

with Session() as lib:
lib.call_module(module="legend", args=build_arg_list(kwargs, infile=spec))
with lib.virtualfile_in(data=spec, required_data=False) as vintbl:
lib.call_module(module="legend", args=build_arg_list(kwargs, infile=vintbl))
13 changes: 13 additions & 0 deletions pygmt/tests/test_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Test Figure.legend.
"""

import io
from pathlib import Path

import pytest
Expand Down Expand Up @@ -100,6 +101,18 @@ def test_legend_specfile(legend_spec):
fig = Figure()
fig.basemap(projection="x6i", region=[0, 1, 0, 1], frame=True)
fig.legend(specfile.name, position="JTM+jCM+w5i")
return fig


@pytest.mark.mpl_image_compare(filename="test_legend_specfile.png")
def test_legend_stringio(legend_spec):
"""
Test passing a legend specification via an io.StringIO object.
"""
spec = io.StringIO(legend_spec)
fig = Figure()
fig.basemap(projection="x6i", region=[0, 1, 0, 1], frame=True)
fig.legend(spec, position="JTM+jCM+w5i")
return fig


Expand Down

0 comments on commit 5196ae6

Please sign in to comment.