Skip to content

Commit

Permalink
refactor!: mark some arguments as positional-only in pybind11
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Oct 24, 2024
1 parent 1f13b82 commit ae89038
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ repos:
hooks:
- id: black
- repo: https://github.com/asottile/pyupgrade
rev: v3.18.0
rev: v3.19.0
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand Down
39 changes: 27 additions & 12 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ limitations under the License.
#include <optional> // std::optional, std::nullopt
#include <string> // std::string

#include <pybind11/cast.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

Expand Down Expand Up @@ -174,6 +173,12 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
py::arg("obj"),
py::pos_only());

#if PYBIND11_VERSION_HEX >= 0x020E00F0 // pybind11 2.14.0
#define def_method_pos_only(...) def(__VA_ARGS__ __VA_OPT__(, ) py::pos_only())
#else
#define def_method_pos_only(...) def(__VA_ARGS__)
#endif

auto PyTreeKindTypeObject =
py::enum_<PyTreeKind>(mod, "PyTreeKind", "The kind of a pytree node.", py::module_local())
.value("CUSTOM", PyTreeKind::Custom, "A custom type.")
Expand Down Expand Up @@ -237,17 +242,23 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
py::arg("f_node"),
py::arg("f_leaf"),
py::arg("leaves"))
.def("paths", &PyTreeSpec::Paths, "Return a list of paths to the leaves of the treespec.")
.def("accessors",
&PyTreeSpec::Accessors,
"Return a list of accessors to the leaves in the treespec.")
.def("entries", &PyTreeSpec::Entries, "Return a list of one-level entries to the children.")
.def_method_pos_only("paths",
&PyTreeSpec::Paths,
"Return a list of paths to the leaves of the treespec.")
.def_method_pos_only("accessors",
&PyTreeSpec::Accessors,
"Return a list of accessors to the leaves in the treespec.")
.def_method_pos_only("entries",
&PyTreeSpec::Entries,
"Return a list of one-level entries to the children.")
.def("entry",
&PyTreeSpec::Entry,
"Return the entry at the given index.",
py::arg("index"),
py::pos_only())
.def("children", &PyTreeSpec::Children, "Return a list of treespecs for the children.")
.def_method_pos_only("children",
&PyTreeSpec::Children,
"Return a list of treespecs for the children.")
.def("child",
&PyTreeSpec::Child,
"Return the treespec for the child at the given index.",
Expand Down Expand Up @@ -332,9 +343,11 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
py::is_operator(),
py::arg("other"),
py::pos_only())
.def("__repr__", &PyTreeSpec::ToString, "Return a string representation of the treespec.")
.def("__hash__", &PyTreeSpec::HashValue, "Return the hash of the treespec.")
.def("__len__", &PyTreeSpec::GetNumLeaves, "Number of leaves in the tree.")
.def_method_pos_only("__repr__",
&PyTreeSpec::ToString,
"Return a string representation of the treespec.")
.def_method_pos_only("__hash__", &PyTreeSpec::HashValue, "Return the hash of the treespec.")
.def_method_pos_only("__len__", &PyTreeSpec::GetNumLeaves, "Number of leaves in the tree.")
.def(py::pickle([](const PyTreeSpec& t) -> py::object { return t.ToPickleable(); },
[](const py::object& o) -> std::unique_ptr<PyTreeSpec> {
return PyTreeSpec::FromPickleable(o);
Expand Down Expand Up @@ -367,8 +380,10 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references]
py::arg("leaf_predicate") = std::nullopt,
py::arg("none_is_leaf") = false,
py::arg("namespace") = "")
.def("__iter__", &PyTreeIter::Iter, "Return the iterator object itself.")
.def("__next__", &PyTreeIter::Next, "Return the next leaf in the pytree.");
.def_method_pos_only("__iter__", &PyTreeIter::Iter, "Return the iterator object itself.")
.def_method_pos_only("__next__", &PyTreeIter::Next, "Return the next leaf in the pytree.");

#undef def_method_pos_only

#ifdef Py_TPFLAGS_IMMUTABLETYPE
PyTreeKind_Type->tp_flags |= Py_TPFLAGS_IMMUTABLETYPE;
Expand Down

0 comments on commit ae89038

Please sign in to comment.