Skip to content

Commit

Permalink
dict_accessor for optimized operator [] of dict
Browse files Browse the repository at this point in the history
  • Loading branch information
lqf96 committed Jan 11, 2021
1 parent e0aa141 commit 24ad818
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 13 deletions.
55 changes: 48 additions & 7 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ namespace accessor_policies {
struct sequence_item;
struct list_item;
struct tuple_item;
struct dict_item;
} // namespace accessor_policies
using obj_attr_accessor = accessor<accessor_policies::obj_attr>;
using str_attr_accessor = accessor<accessor_policies::str_attr>;
using item_accessor = accessor<accessor_policies::generic_item>;
using sequence_accessor = accessor<accessor_policies::sequence_item>;
using list_accessor = accessor<accessor_policies::list_item>;
using tuple_accessor = accessor<accessor_policies::tuple_item>;
using dict_accessor = accessor<accessor_policies::dict_item>;

/// Tag and check to identify a class which implements the Python object API
class pyobject_tag { };
Expand Down Expand Up @@ -613,6 +615,31 @@ struct tuple_item {
}
}
};

struct dict_item {
using key_type = object;

static object get(handle obj, handle key) {
#if PY_MAJOR_VERSION >= 3
if (PyObject *result = PyDict_GetItemWithError(obj.ptr(), key.ptr())) {
return reinterpret_borrow<object>(result);
} else {
if (!PyErr_Occurred())
if (PyObject* key_repr = PyObject_Repr(key.ptr()))
PyErr_SetObject(PyExc_KeyError, key_repr);
throw error_already_set();
}
#else
return generic_item::get(obj, key);
#endif
}

static void set(handle obj, handle key, handle val) {
if (PyDict_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) {
throw error_already_set();
}
}
};
PYBIND11_NAMESPACE_END(accessor_policies)

/// STL iterator template used for tuple, list, sequence and dict
Expand Down Expand Up @@ -1285,6 +1312,8 @@ class dict : public object {

size_t size() const { return (size_t) PyDict_Size(m_ptr); }
bool empty() const { return size() == 0; }
detail::dict_accessor operator[](const char *key) const { return {*this, pybind11::str(key)}; }
detail::dict_accessor operator[](handle h) const { return {*this, reinterpret_borrow<object>(h)}; }
detail::dict_iterator begin() const { return {*this, 0}; }
detail::dict_iterator end() const { return {}; }
void clear() const { PyDict_Clear(ptr()); }
Expand All @@ -1293,19 +1322,31 @@ class dict : public object {
}

object get(handle key, handle default_ = none()) const {
if (PyObject *result = PyDict_GetItem(m_ptr, key.ptr())) {
#if PY_MAJOR_VERSION >= 3
if (PyObject *result = PyDict_GetItemWithError(m_ptr, key.ptr())) {
return reinterpret_borrow<object>(result);
} else {
return reinterpret_borrow<object>(default_);
if (PyErr_Occurred())
throw error_already_set();
else
return reinterpret_borrow<object>(default_);
}
#else
try {
return generic_item::get(*this, key);
} catch (...) {
if (PyErr_Occurred() == PyExc_KeyError) {
PyErr_Clear();
return reinterpret_borrow<object>(default_);
} else {
throw;
}
}
#endif
}

object get(const char *key, handle default_ = none()) const {
if (PyObject *result = PyDict_GetItemString(m_ptr, key)) {
return reinterpret_borrow<object>(result);
} else {
return reinterpret_borrow<object>(default_);
}
return get(pybind11::str(key), default_);
}

private:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,22 @@ TEST_SUBMODULE(pytypes, m) {
auto d2 = py::dict("z"_a=3, **d1);
return d2;
});
m.def("dict_contains", [](py::dict dict, py::object val) {
m.def("dict_contains", [](py::dict dict, const char* val) {
return dict.contains(val);
});
m.def("dict_contains", [](py::dict dict, const char* val) {
m.def("dict_contains", [](py::dict dict, py::object val) {
return dict.contains(val);
});
m.def("dict_get", [](py::dict dict, py::object key, py::object default_) {
m.def("dict_get", [](py::dict dict, const char* key, py::object default_) {
return dict.get(key, default_);
});
m.def("dict_get", [](py::dict dict, const char* key, py::object default_) {
m.def("dict_get", [](py::dict dict, py::object key, py::object default_) {
return dict.get(key, default_);
});
m.def("dict_get", [](py::dict dict, py::object key) {
m.def("dict_get", [](py::dict dict, const char* key) {
return dict.get(key);
});
m.def("dict_get", [](py::dict dict, const char* key) {
m.def("dict_get", [](py::dict dict, py::object key) {
return dict.get(key);
});

Expand Down

0 comments on commit 24ad818

Please sign in to comment.