diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 2296f6f629c..d5589055e0e 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -35,6 +35,7 @@ namespace accessor_policies { struct sequence_item; struct list_item; struct tuple_item; + struct dict_item; } // namespace accessor_policies using obj_attr_accessor = accessor; using str_attr_accessor = accessor; @@ -42,6 +43,7 @@ using item_accessor = accessor; using sequence_accessor = accessor; using list_accessor = accessor; using tuple_accessor = accessor; +using dict_accessor = accessor; /// Tag and check to identify a class which implements the Python object API class pyobject_tag { }; @@ -613,6 +615,31 @@ struct tuple_item { } } }; + +struct dict_item { + using key_type = object; + + static object get(handle obj, handle key) { +#if PY_MAJOR_VERSION >= 3 + if (PyObject *result = PyDict_GetItemWithError(obj.ptr(), key.ptr())) { + return reinterpret_borrow(result); + } else { + if (!PyErr_Occurred()) + if (PyObject* key_repr = PyObject_Repr(key.ptr())) + PyErr_SetObject(PyExc_KeyError, key_repr); + throw error_already_set(); + } +#else + return generic_item::get(obj, key); +#endif + } + + static void set(handle obj, handle key, handle val) { + if (PyDict_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) { + throw error_already_set(); + } + } +}; PYBIND11_NAMESPACE_END(accessor_policies) /// STL iterator template used for tuple, list, sequence and dict @@ -1285,6 +1312,8 @@ class dict : public object { size_t size() const { return (size_t) PyDict_Size(m_ptr); } bool empty() const { return size() == 0; } + detail::dict_accessor operator[](const char *key) const { return {*this, pybind11::str(key)}; } + detail::dict_accessor operator[](handle h) const { return {*this, reinterpret_borrow(h)}; } detail::dict_iterator begin() const { return {*this, 0}; } detail::dict_iterator end() const { return {}; } void clear() const { PyDict_Clear(ptr()); } @@ -1293,19 +1322,31 @@ class dict : public object { } object get(handle key, handle default_ = none()) const { - if (PyObject *result = PyDict_GetItem(m_ptr, key.ptr())) { +#if PY_MAJOR_VERSION >= 3 + if (PyObject *result = PyDict_GetItemWithError(m_ptr, key.ptr())) { return reinterpret_borrow(result); } else { - return reinterpret_borrow(default_); + if (PyErr_Occurred()) + throw error_already_set(); + else + return reinterpret_borrow(default_); } +#else + try { + return generic_item::get(*this, key); + } catch (...) { + if (PyErr_Occurred() == PyExc_KeyError) { + PyErr_Clear(); + return reinterpret_borrow(default_); + } else { + throw; + } + } +#endif } object get(const char *key, handle default_ = none()) const { - if (PyObject *result = PyDict_GetItemString(m_ptr, key)) { - return reinterpret_borrow(result); - } else { - return reinterpret_borrow(default_); - } + return get(pybind11::str(key), default_); } private: diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index 2a468a49888..fd4e61132c5 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -68,22 +68,22 @@ TEST_SUBMODULE(pytypes, m) { auto d2 = py::dict("z"_a=3, **d1); return d2; }); - m.def("dict_contains", [](py::dict dict, py::object val) { + m.def("dict_contains", [](py::dict dict, const char* val) { return dict.contains(val); }); - m.def("dict_contains", [](py::dict dict, const char* val) { + m.def("dict_contains", [](py::dict dict, py::object val) { return dict.contains(val); }); - m.def("dict_get", [](py::dict dict, py::object key, py::object default_) { + m.def("dict_get", [](py::dict dict, const char* key, py::object default_) { return dict.get(key, default_); }); - m.def("dict_get", [](py::dict dict, const char* key, py::object default_) { + m.def("dict_get", [](py::dict dict, py::object key, py::object default_) { return dict.get(key, default_); }); - m.def("dict_get", [](py::dict dict, py::object key) { + m.def("dict_get", [](py::dict dict, const char* key) { return dict.get(key); }); - m.def("dict_get", [](py::dict dict, const char* key) { + m.def("dict_get", [](py::dict dict, py::object key) { return dict.get(key); });