Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyOV] Restrict changing data in const #27431

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/bindings/python/src/pyopenvino/core/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,14 @@ py::array array_from_constant_copy(ov::op::v0::Constant&& c, py::dtype& dst_dtyp
py::array array_from_constant_view(ov::op::v0::Constant&& c) {
const auto& ov_type = c.get_element_type();
const auto dtype = Common::type_helpers::get_dtype(ov_type);
py::array data;
if (ov_type.bitwidth() < Common::values::min_bitwidth) {
return py::array(dtype, c.get_byte_size(), c.get_data_ptr(), py::cast(c));
data = py::array(dtype, c.get_byte_size(), c.get_data_ptr(), py::cast(c));
} else {
data = py::array(dtype, c.get_shape(), constant_helpers::_get_strides(c), c.get_data_ptr(), py::cast(c));
}
return py::array(dtype, c.get_shape(), constant_helpers::_get_strides(c), c.get_data_ptr(), py::cast(c));
data.attr("flags").attr("writeable") = false;
return data;
}

}; // namespace array_helpers
Expand Down
116 changes: 0 additions & 116 deletions src/bindings/python/tests/test_graph/test_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,55 +205,6 @@ def test_init_with_scalar(init_value, src_dtype, dst_dtype, shared_flag, data_ge
assert np.allclose(const_data, expected_result)


@pytest.mark.parametrize(
("src_dtype"),
[
(np.float16),
(np.uint16),
],
)
@pytest.mark.parametrize(
("shared_flag"),
[
(True),
(False),
],
)
@pytest.mark.parametrize(
("data_getter"),
[
(DataGetter.COPY),
(DataGetter.VIEW),
],
)
def test_init_bf16_populate(src_dtype, shared_flag, data_getter):
data = np.random.rand(1, 2, 16, 8) + 0.5
data = data.astype(src_dtype)

# To create bf16 constant, allocate memory and populate it:
init_data = np.zeros(shape=data.shape, dtype=src_dtype)
ov_const = ops.constant(init_data, dtype=Type.bf16, shared_memory=shared_flag)
ov_const.data[:] = data
akuporos marked this conversation as resolved.
Show resolved Hide resolved

# Check shape and element type of Constant class
assert isinstance(ov_const, Constant)
assert np.all(list(ov_const.shape) == [1, 2, 16, 8])
assert ov_const.get_element_type() == Type.bf16

_dst_dtype = Type.bf16.to_dtype()

assert ov_const.get_element_type().to_dtype() == _dst_dtype
# Compare values to Constant
if data_getter == DataGetter.COPY:
const_data = ov_const.get_data()
elif data_getter == DataGetter.VIEW:
const_data = ov_const.data
else:
raise AttributeError("Unknown DataGetter passed!")
assert const_data.dtype == _dst_dtype
assert np.allclose(const_data, data)


@pytest.mark.parametrize(
("ov_type", "numpy_dtype"),
[
Expand Down Expand Up @@ -286,58 +237,6 @@ def test_init_bf16_direct(ov_type, numpy_dtype, shared_flag):
assert np.allclose(data, result, rtol=0.01)


@pytest.mark.parametrize(
"shape",
[
([1, 3, 28, 28]),
([1, 3, 27, 27]),
],
)
@pytest.mark.parametrize(
("low", "high", "ov_type", "src_dtype"),
[
(0, 2, Type.u1, np.uint8),
(0, 16, Type.u4, np.uint8),
(-8, 7, Type.i4, np.int8),
(0, 16, Type.nf4, np.uint8),
],
)
@pytest.mark.parametrize(
("shared_flag"),
[
(True),
(False),
],
)
@pytest.mark.parametrize(
("data_getter"),
[
(DataGetter.COPY),
(DataGetter.VIEW),
],
)
def test_constant_helper_packing(shape, low, high, ov_type, src_dtype, shared_flag, data_getter):
data = np.random.uniform(low, high, shape).astype(src_dtype)

# Allocate memory first:
ov_const = ops.constant(np.zeros(shape=data.shape, dtype=src_dtype),
dtype=ov_type,
shared_memory=shared_flag)
# Fill data with packed values
packed_data = pack_data(data, ov_const.get_element_type())
ov_const.data[:] = packed_data

# Always unpack the data!
if data_getter == DataGetter.COPY:
unpacked = unpack_data(ov_const.get_data(), ov_const.get_element_type(), ov_const.shape)
elif data_getter == DataGetter.VIEW:
unpacked = unpack_data(ov_const.data, ov_const.get_element_type(), ov_const.shape)
else:
raise AttributeError("Unknown DataGetter passed!")

assert np.array_equal(unpacked, data)


@pytest.mark.parametrize(
("ov_type", "src_dtype"),
[
Expand Down Expand Up @@ -380,21 +279,6 @@ def test_constant_direct_packing(ov_type, src_dtype, shared_flag, data_getter):
assert not np.shares_memory(unpacked, data)


@pytest.mark.parametrize(
("shared_flag"),
[
(True),
(False),
],
)
def test_write_to_buffer(shared_flag):
arr_0 = np.ones([1, 3, 32, 32])
ov_const = ops.constant(arr_0, shared_memory=shared_flag)
arr_1 = np.ones([1, 3, 32, 32]) + 1
ov_const.data[:] = arr_1
assert np.array_equal(ov_const.data, arr_1)


@pytest.mark.parametrize(
("shared_flag"),
[
Expand Down
Loading