diff --git a/include/oneapi/math/rng/device/detail/distribution_base.hpp b/include/oneapi/math/rng/device/detail/distribution_base.hpp index 4faf6cd49..9a76dd3dd 100644 --- a/include/oneapi/math/rng/device/detail/distribution_base.hpp +++ b/include/oneapi/math/rng/device/detail/distribution_base.hpp @@ -65,6 +65,9 @@ class poisson; template class bernoulli; +template +class geometric; + } // namespace oneapi::math::rng::device #include "oneapi/math/rng/device/detail/uniform_impl.hpp" @@ -75,6 +78,7 @@ class bernoulli; #include "oneapi/math/rng/device/detail/exponential_impl.hpp" #include "oneapi/math/rng/device/detail/poisson_impl.hpp" #include "oneapi/math/rng/device/detail/bernoulli_impl.hpp" +#include "oneapi/math/rng/device/detail/geometric_impl.hpp" #include "oneapi/math/rng/device/detail/beta_impl.hpp" #include "oneapi/math/rng/device/detail/gamma_impl.hpp" diff --git a/include/oneapi/math/rng/device/detail/geometric_impl.hpp b/include/oneapi/math/rng/device/detail/geometric_impl.hpp new file mode 100644 index 000000000..3466db8b7 --- /dev/null +++ b/include/oneapi/math/rng/device/detail/geometric_impl.hpp @@ -0,0 +1,99 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_ +#define ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_ + +namespace oneapi::math::rng::device::detail { + +template +class distribution_base> { +public: + struct param_type { + param_type(float p) : p_(p) {} + float p_; + }; + + distribution_base(float p) : p_(p) { +#ifndef __SYCL_DEVICE_ONLY__ + if ((p > 1.0f) || (p < 0.0f)) { + throw oneapi::math::invalid_argument("rng", "geometric", "p < 0 || p > 1"); + } +#endif + } + + float p() const { + return p_; + } + + param_type param() const { + return param_type(p_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if ((pt.p_ > 1.0f) || (pt.p_ < 0.0f)) { + throw oneapi::math::invalid_argument("rng", "geometric", "p < 0 || p > 1"); + } +#endif + p_ = pt.p_; + } + +protected: + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + using FpType = typename std::conditional || + std::is_same_v, + double, float>::type; + + auto uni_res = engine.generate(FpType(0.0), FpType(1.0)); + FpType inv_ln = ln_wrapper(FpType(1.0) - p_); + inv_ln = FpType(1.0) / inv_ln; + if constexpr (EngineType::vec_size == 1) { + return static_cast(sycl::floor(ln_wrapper(uni_res) * inv_ln)); + } + else { + sycl::vec vec_out; + for (int i = 0; i < EngineType::vec_size; i++) { + vec_out[i] = static_cast(sycl::floor(ln_wrapper(uni_res[i]) * inv_ln)); + } + return vec_out; + } + } + + template + IntType generate_single(EngineType& engine) { + using FpType = typename std::conditional || + std::is_same_v, + double, float>::type; + + FpType uni_res = engine.generate_single(FpType(0.0), FpType(1.0)); + FpType inv_ln = ln_wrapper(FpType(1.0) - p_); + inv_ln = FpType(1.0) / inv_ln; + return static_cast(sycl::floor(ln_wrapper(uni_res) * inv_ln)); + } + + float p_; +}; + +} // namespace oneapi::math::rng::device::detail + +#endif // ONEMATH_RNG_DEVICE_GEOMETRIC_IMPL_HPP_ diff --git a/include/oneapi/math/rng/device/distributions.hpp b/include/oneapi/math/rng/device/distributions.hpp index 6def09d45..a7c468eac 100644 --- a/include/oneapi/math/rng/device/distributions.hpp +++ b/include/oneapi/math/rng/device/distributions.hpp @@ -632,6 +632,64 @@ class bernoulli : detail::distribution_base> { friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); }; +// Class template oneapi::math::rng::device::geometric +// +// Represents discrete geometric random number distribution +// +// Supported types: +// std::uint32_t +// std::int32_t +// std::uint64_t +// std::int64_t +// +// Supported methods: +// oneapi::math::rng::geometric_method::icdf; +// +// Input arguments: +// p - success probablity of a trial. 0.5 by default +// +template +class geometric : detail::distribution_base> { +public: + static_assert(std::is_same::value, + "oneMath: rng/geometric: method is incorrect"); + + static_assert(std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value, + "oneMath: rng/geometric: type is not supported"); + + using method_type = Method; + using result_type = IntType; + using param_type = typename detail::distribution_base>::param_type; + + geometric() : detail::distribution_base>(0.5f) {} + + explicit geometric(float p) : detail::distribution_base>(p) {} + explicit geometric(const param_type& pt) + : detail::distribution_base>(pt.p_) {} + + float p() const { + return detail::distribution_base>::p(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + } // namespace oneapi::math::rng::device #endif // ONEMATH_RNG_DEVICE_DISTRIBUTIONS_HPP_ diff --git a/include/oneapi/math/rng/device/types.hpp b/include/oneapi/math/rng/device/types.hpp index d2cb7ac37..158079e58 100644 --- a/include/oneapi/math/rng/device/types.hpp +++ b/include/oneapi/math/rng/device/types.hpp @@ -57,6 +57,11 @@ struct icdf {}; using by_default = icdf; } // namespace bernoulli_method +namespace geometric_method { +struct icdf {}; +using by_default = icdf; +} // namespace geometric_method + namespace beta_method { struct cja {}; struct cja_accurate {}; diff --git a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp index 5e373e2cf..1bc954e49 100644 --- a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp +++ b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp @@ -352,6 +352,22 @@ struct statistics_device> { } }; +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::math::rng::device::geometric& distr) { + double tM, tD, tQ; + double p = static_cast(distr.p()); + + tM = (1.0 - p) / p; + tD = (1.0 - p) / (p * p); + tQ = (1.0 - p) * (p * p - 9.0 * p + 9.0) / (p * p * p * p); + + return compare_moments(r, tM, tD, tQ); + } +}; + template struct statistics_device> { template diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp index a191b67df..110ccbb0b 100644 --- a/tests/unit_tests/rng/device/moments/moments.cpp +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -1416,4 +1416,99 @@ INSTANTIATE_TEST_SUITE_P(Philox4x32x10BernoulliIcdfDeviceMomentsTestsSuite, Philox4x32x10BernoulliIcdfDeviceMomentsTests, ::testing::ValuesIn(devices), ::DeviceNamePrint()); +class Philox4x32x10GeometricIcdfDeviceMomentsTests + : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint32_t, oneapi::math::rng::device::geometric_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, Integer64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::int64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10GeometricIcdfDeviceMomentsTests, UnsignedInteger64Precision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::math::rng::device::geometric< + std::uint64_t, oneapi::math::rng::device::geometric_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10GeometricIcdfDeviceMomentsTestsSuite, + Philox4x32x10GeometricIcdfDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + } // anonymous namespace