From 31e8e195c9d4436690f8c34384984e71e8c69d7e Mon Sep 17 00:00:00 2001 From: jatin Date: Mon, 13 Nov 2023 17:53:45 +0000 Subject: [PATCH] Initial commit --- .clang-format | 75 ++ .gitignore | 4 + CMakeLists.txt | 28 + README.md | 5 + cmake/CPM.cmake | 1116 ++++++++++++++++++++ include/math_approx/math_approx.hpp | 12 + include/math_approx/src/basic_math.hpp | 43 + include/math_approx/src/sigmoid_approx.hpp | 76 ++ include/math_approx/src/tanh_approx.hpp | 81 ++ test/CMakeLists.txt | 32 + test/src/sigmoid_approx_test.cpp | 48 + test/src/tanh_approx_test.cpp | 72 ++ test/src/test_helpers.hpp | 124 +++ tools/CMakeLists.txt | 2 + tools/bench/CMakeLists.txt | 12 + tools/bench/sigmoid_bench.cpp | 51 + tools/bench/tanh_bench.cpp | 53 + tools/plotter/CMakeLists.txt | 8 + tools/plotter/plotter.cpp | 80 ++ 19 files changed, 1922 insertions(+) create mode 100644 .clang-format create mode 100644 .gitignore create mode 100644 CMakeLists.txt create mode 100644 README.md create mode 100644 cmake/CPM.cmake create mode 100644 include/math_approx/math_approx.hpp create mode 100644 include/math_approx/src/basic_math.hpp create mode 100644 include/math_approx/src/sigmoid_approx.hpp create mode 100644 include/math_approx/src/tanh_approx.hpp create mode 100644 test/CMakeLists.txt create mode 100644 test/src/sigmoid_approx_test.cpp create mode 100644 test/src/tanh_approx_test.cpp create mode 100644 test/src/test_helpers.hpp create mode 100644 tools/CMakeLists.txt create mode 100644 tools/bench/CMakeLists.txt create mode 100644 tools/bench/sigmoid_bench.cpp create mode 100644 tools/bench/tanh_bench.cpp create mode 100644 tools/plotter/CMakeLists.txt create mode 100644 tools/plotter/plotter.cpp diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..4483fbe --- /dev/null +++ b/.clang-format @@ -0,0 +1,75 @@ +--- +AccessModifierOffset: -4 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left 
+AlignOperands: Align +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BreakAfterJavaFieldAnnotations: false +BreakBeforeBinaryOperators: NonAssignment +BreakBeforeBraces: Allman +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakStringLiterals: false +ColumnLimit: 0 +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: false +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +ForEachMacros: ['forEachXmlChildElement'] +IndentCaseLabels: true +IndentWidth: 4 +IndentWrappedFunctionNames: true +KeepEmptyLinesAtTheStartOfBlocks: false +Language: Cpp +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: Inner +PointerAlignment: Left +ReflowComments: false +SortIncludes: true +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: true +SpaceBeforeParens: NonEmptyParentheses +SpaceInEmptyParentheses: false +SpaceBeforeInheritanceColon: true +SpacesInAngles: false +SpacesInCStyleCastParentheses: false +SpacesInContainerLiterals: true +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: "c++17" +TabWidth: 4 +UseTab: Never +--- +Language: ObjC +BasedOnStyle: Chromium +AlignTrailingComments: true +BreakBeforeBraces: Allman +ColumnLimit: 0 +IndentWidth: 4 +KeepEmptyLinesAtTheStartOfBlocks: false +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: true +PointerAlignment: Left 
+SpacesBeforeTrailingComments: 1 +TabWidth: 4 +UseTab: Never +... diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..98b57a2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.idea/ +.vscode/ + +build*/ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..864a1d5 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 3.21) +project(math_approx) +set(CMAKE_CXX_STANDARD 20) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") # resolved next to this file: CMAKE_SOURCE_DIR would be the consuming project's root when math_approx is a subproject + +if(PROJECT_IS_TOP_LEVEL) # xsimd is only fetched when math_approx is the top-level project + include(CPM) + CPMAddPackage( + NAME xsimd + GIT_REPOSITORY https://github.com/xtensor-stack/xsimd + GIT_TAG master # NOTE(review): master is a moving branch, so the fetched revision can drift; consider pinning a tag or commit + ) +endif() + +add_library(math_approx INTERFACE) # INTERFACE target: carries usage requirements only, nothing is compiled here +target_include_directories(math_approx INTERFACE include) +if (TARGET xsimd) # enable XSIMD whenever an xsimd target exists, whether fetched above or defined by a parent project + message(STATUS "math_approx -- Linking with XSIMD...") + target_link_libraries(math_approx INTERFACE xsimd) + target_compile_definitions(math_approx INTERFACE MATH_APPROX_XSIMD_TARGET=1) +endif() + +if(PROJECT_IS_TOP_LEVEL) # test/ and tools/ are only added when math_approx is the top-level project + include(CTest) + add_subdirectory(test) + add_subdirectory(tools) +endif() diff --git a/README.md b/README.md new file mode 100644 index 0000000..bcd86e7 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# math_approx + +A C++ library for math approximations. + +More info coming soon! diff --git a/cmake/CPM.cmake b/cmake/CPM.cmake new file mode 100644 index 0000000..464634a --- /dev/null +++ b/cmake/CPM.cmake @@ -0,0 +1,1116 @@ +# CPM.cmake - CMake's missing package manager +# =========================================== +# See https://github.com/cpm-cmake/CPM.cmake for usage and update instructions.
+# +# MIT License +# ----------- +#[[ + Copyright (c) 2019-2022 Lars Melchior and contributors + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +]] + +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +# Initialize logging prefix +if(NOT CPM_INDENT) + set(CPM_INDENT + "CPM:" + CACHE INTERNAL "" + ) +endif() + +if(NOT COMMAND cpm_message) + function(cpm_message) + message(${ARGV}) + endfunction() +endif() + +set(CURRENT_CPM_VERSION 1.0.0-development-version) + +get_filename_component(CPM_CURRENT_DIRECTORY "${CMAKE_CURRENT_LIST_DIR}" REALPATH) +if(CPM_DIRECTORY) + if(NOT CPM_DIRECTORY STREQUAL CPM_CURRENT_DIRECTORY) + if(CPM_VERSION VERSION_LESS CURRENT_CPM_VERSION) + message( + AUTHOR_WARNING + "${CPM_INDENT} \ +A dependency is using a more recent CPM version (${CURRENT_CPM_VERSION}) than the current project (${CPM_VERSION}). \ +It is recommended to upgrade CPM to the most recent version. \ +See https://github.com/cpm-cmake/CPM.cmake for more information." 
+ ) + endif() + if(${CMAKE_VERSION} VERSION_LESS "3.17.0") + include(FetchContent) + endif() + return() + endif() + + get_property( + CPM_INITIALIZED GLOBAL "" + PROPERTY CPM_INITIALIZED + SET + ) + if(CPM_INITIALIZED) + return() + endif() +endif() + +set_property(GLOBAL PROPERTY CPM_INITIALIZED true) + +macro(cpm_set_policies) + # the policy allows us to change options without caching + cmake_policy(SET CMP0077 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + # the policy allows us to change set(CACHE) without caching + if(POLICY CMP0126) + cmake_policy(SET CMP0126 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0126 NEW) + endif() + + # The policy uses the download time for timestamp, instead of the timestamp in the archive. This + # allows for proper rebuilds when a projects url changes + if(POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) + set(CMAKE_POLICY_DEFAULT_CMP0135 NEW) + endif() +endmacro() +cpm_set_policies() + +option(CPM_USE_LOCAL_PACKAGES "Always try to use `find_package` to get dependencies" + $ENV{CPM_USE_LOCAL_PACKAGES} + ) +option(CPM_LOCAL_PACKAGES_ONLY "Only use `find_package` to get dependencies" + $ENV{CPM_LOCAL_PACKAGES_ONLY} + ) +option(CPM_DOWNLOAD_ALL "Always download dependencies from source" $ENV{CPM_DOWNLOAD_ALL}) +option(CPM_DONT_UPDATE_MODULE_PATH "Don't update the module path to allow using find_package" + $ENV{CPM_DONT_UPDATE_MODULE_PATH} + ) +option(CPM_DONT_CREATE_PACKAGE_LOCK "Don't create a package lock file in the binary path" + $ENV{CPM_DONT_CREATE_PACKAGE_LOCK} + ) +option(CPM_INCLUDE_ALL_IN_PACKAGE_LOCK + "Add all packages added through CPM.cmake to the package lock" + $ENV{CPM_INCLUDE_ALL_IN_PACKAGE_LOCK} + ) +option(CPM_USE_NAMED_CACHE_DIRECTORIES + "Use additional directory of package name in cache on the most nested level." 
+ $ENV{CPM_USE_NAMED_CACHE_DIRECTORIES} + ) + +set(CPM_VERSION + ${CURRENT_CPM_VERSION} + CACHE INTERNAL "" + ) +set(CPM_DIRECTORY + ${CPM_CURRENT_DIRECTORY} + CACHE INTERNAL "" + ) +set(CPM_FILE + ${CMAKE_CURRENT_LIST_FILE} + CACHE INTERNAL "" + ) +set(CPM_PACKAGES + "" + CACHE INTERNAL "" + ) +set(CPM_DRY_RUN + OFF + CACHE INTERNAL "Don't download or configure dependencies (for testing)" + ) + +if(DEFINED ENV{CPM_SOURCE_CACHE}) + set(CPM_SOURCE_CACHE_DEFAULT $ENV{CPM_SOURCE_CACHE}) +else() + set(CPM_SOURCE_CACHE_DEFAULT OFF) +endif() + +set(CPM_SOURCE_CACHE + ${CPM_SOURCE_CACHE_DEFAULT} + CACHE PATH "Directory to download CPM dependencies" + ) + +if(NOT CPM_DONT_UPDATE_MODULE_PATH) + set(CPM_MODULE_PATH + "${CMAKE_BINARY_DIR}/CPM_modules" + CACHE INTERNAL "" + ) + # remove old modules + file(REMOVE_RECURSE ${CPM_MODULE_PATH}) + file(MAKE_DIRECTORY ${CPM_MODULE_PATH}) + # locally added CPM modules should override global packages + set(CMAKE_MODULE_PATH "${CPM_MODULE_PATH};${CMAKE_MODULE_PATH}") +endif() + +if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + set(CPM_PACKAGE_LOCK_FILE + "${CMAKE_BINARY_DIR}/cpm-package-lock.cmake" + CACHE INTERNAL "" + ) + file(WRITE ${CPM_PACKAGE_LOCK_FILE} + "# CPM Package Lock\n# This file should be committed to version control\n\n" + ) +endif() + +include(FetchContent) + +# Try to infer package name from git repository uri (path or url) +function(cpm_package_name_from_git_uri URI RESULT) + if("${URI}" MATCHES "([^/:]+)/?.git/?$") + set(${RESULT} + ${CMAKE_MATCH_1} + PARENT_SCOPE + ) + else() + unset(${RESULT} PARENT_SCOPE) + endif() +endfunction() + +# Try to infer package name and version from a url +function(cpm_package_name_and_ver_from_url url outName outVer) + if(url MATCHES "[/\\?]([a-zA-Z0-9_\\.-]+)\\.(tar|tar\\.gz|tar\\.bz2|zip|ZIP)(\\?|/|$)") + # We matched an archive + set(filename "${CMAKE_MATCH_1}") + + if(filename MATCHES "([a-zA-Z0-9_\\.-]+)[_-]v?(([0-9]+\\.)*[0-9]+[a-zA-Z0-9]*)") + # We matched NAME-VERSION (ie foo-1.2.3) +
set(${outName} + "${CMAKE_MATCH_1}" + PARENT_SCOPE + ) + set(${outVer} + "${CMAKE_MATCH_2}" + PARENT_SCOPE + ) + elseif(filename MATCHES "(([0-9]+\\.)+[0-9]+[a-zA-Z0-9]*)") + # We couldn't find a name, but we found a version + # + # In many cases (which we don't handle here) the url would look something like + # `irrelevant/ACTUAL_PACKAGE_NAME/irrelevant/1.2.3.zip`. In such a case we can't possibly + # distinguish the package name from the irrelevant bits. Moreover if we try to match the + # package name from the filename, we'd get bogus at best. + unset(${outName} PARENT_SCOPE) + set(${outVer} + "${CMAKE_MATCH_1}" + PARENT_SCOPE + ) + else() + # Boldly assume that the file name is the package name. + # + # Yes, something like `irrelevant/ACTUAL_NAME/irrelevant/download.zip` will ruin our day, but + # such cases should be quite rare. No popular service does this... we think. + set(${outName} + "${filename}" + PARENT_SCOPE + ) + unset(${outVer} PARENT_SCOPE) + endif() + else() + # No ideas yet what to do with non-archives + unset(${outName} PARENT_SCOPE) + unset(${outVer} PARENT_SCOPE) + endif() +endfunction() + +function(cpm_find_package NAME VERSION) + string(REPLACE " " ";" EXTRA_ARGS "${ARGN}") + find_package(${NAME} ${VERSION} ${EXTRA_ARGS} QUIET) + if(${CPM_ARGS_NAME}_FOUND) + if(DEFINED ${CPM_ARGS_NAME}_VERSION) + set(VERSION ${${CPM_ARGS_NAME}_VERSION}) + endif() + cpm_message(STATUS "${CPM_INDENT} Using local package ${CPM_ARGS_NAME}@${VERSION}") + CPMRegisterPackage(${CPM_ARGS_NAME} "${VERSION}") + set(CPM_PACKAGE_FOUND + YES + PARENT_SCOPE + ) + else() + set(CPM_PACKAGE_FOUND + NO + PARENT_SCOPE + ) + endif() +endfunction() + +# Create a custom FindXXX.cmake module for a CPM package. This prevents `find_package(NAME)` from +# finding the system library. +function(cpm_create_module_file Name) + if(NOT CPM_DONT_UPDATE_MODULE_PATH) + # erase any previous modules + file(WRITE ${CPM_MODULE_PATH}/Find${Name}.cmake +
"include(\"${CPM_FILE}\")\n${ARGN}\nset(${Name}_FOUND TRUE)" + ) + endif() +endfunction() + +# Find a package locally or fallback to CPMAddPackage +function(CPMFindPackage) + set(oneValueArgs NAME VERSION GIT_TAG FIND_PACKAGE_ARGUMENTS) + + cmake_parse_arguments(CPM_ARGS "" "${oneValueArgs}" "" ${ARGN}) + + if(NOT DEFINED CPM_ARGS_VERSION) + if(DEFINED CPM_ARGS_GIT_TAG) + cpm_get_version_from_git_tag("${CPM_ARGS_GIT_TAG}" CPM_ARGS_VERSION) + endif() + endif() + + set(downloadPackage ${CPM_DOWNLOAD_ALL}) + if(DEFINED CPM_DOWNLOAD_${CPM_ARGS_NAME}) + set(downloadPackage ${CPM_DOWNLOAD_${CPM_ARGS_NAME}}) + elseif(DEFINED ENV{CPM_DOWNLOAD_${CPM_ARGS_NAME}}) + set(downloadPackage $ENV{CPM_DOWNLOAD_${CPM_ARGS_NAME}}) + endif() + if(downloadPackage) + CPMAddPackage(${ARGN}) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + cpm_check_if_package_already_added(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}") + if(CPM_PACKAGE_ALREADY_ADDED) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + cpm_find_package(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}" ${CPM_ARGS_FIND_PACKAGE_ARGUMENTS}) + + if(NOT CPM_PACKAGE_FOUND) + CPMAddPackage(${ARGN}) + cpm_export_variables(${CPM_ARGS_NAME}) + endif() + +endfunction() + +# checks if a package has been added before +function(cpm_check_if_package_already_added CPM_ARGS_NAME CPM_ARGS_VERSION) + if("${CPM_ARGS_NAME}" IN_LIST CPM_PACKAGES) + CPMGetPackageVersion(${CPM_ARGS_NAME} CPM_PACKAGE_VERSION) + if("${CPM_PACKAGE_VERSION}" VERSION_LESS "${CPM_ARGS_VERSION}") + message( + WARNING + "${CPM_INDENT} Requires a newer version of ${CPM_ARGS_NAME} (${CPM_ARGS_VERSION}) than currently included (${CPM_PACKAGE_VERSION})." 
+ ) + endif() + cpm_get_fetch_properties(${CPM_ARGS_NAME}) + set(${CPM_ARGS_NAME}_ADDED NO) + set(CPM_PACKAGE_ALREADY_ADDED + YES + PARENT_SCOPE + ) + cpm_export_variables(${CPM_ARGS_NAME}) + else() + set(CPM_PACKAGE_ALREADY_ADDED + NO + PARENT_SCOPE + ) + endif() +endfunction() + +# Parse the argument of CPMAddPackage in case a single one was provided and convert it to a list of +# arguments which can then be parsed idiomatically. For example gh:foo/bar@1.2.3 will be converted +# to: GITHUB_REPOSITORY;foo/bar;VERSION;1.2.3 +function(cpm_parse_add_package_single_arg arg outArgs) + # Look for a scheme + if("${arg}" MATCHES "^([a-zA-Z]+):(.+)$") + string(TOLOWER "${CMAKE_MATCH_1}" scheme) + set(uri "${CMAKE_MATCH_2}") + + # Check for CPM-specific schemes + if(scheme STREQUAL "gh") + set(out "GITHUB_REPOSITORY;${uri}") + set(packageType "git") + elseif(scheme STREQUAL "gl") + set(out "GITLAB_REPOSITORY;${uri}") + set(packageType "git") + elseif(scheme STREQUAL "bb") + set(out "BITBUCKET_REPOSITORY;${uri}") + set(packageType "git") + # A CPM-specific scheme was not found. Looks like this is a generic URL so try to determine + # type + elseif(arg MATCHES ".git/?(@|#|$)") + set(out "GIT_REPOSITORY;${arg}") + set(packageType "git") + else() + # Fall back to a URL + set(out "URL;${arg}") + set(packageType "archive") + + # We could also check for SVN since FetchContent supports it, but SVN is so rare these days. + # We just won't bother with the additional complexity it will induce in this function. SVN is + # done by multi-arg + endif() + else() + if(arg MATCHES ".git/?(@|#|$)") + set(out "GIT_REPOSITORY;${arg}") + set(packageType "git") + else() + # Give up + message(FATAL_ERROR "${CPM_INDENT} Can't determine package type of '${arg}'") + endif() + endif() + + # For all packages we interpret @... as version. Only replace the last occurrence. 
Thus URIs + # containing '@' can be used + string(REGEX REPLACE "@([^@]+)$" ";VERSION;\\1" out "${out}") + + # Parse the rest according to package type + if(packageType STREQUAL "git") + # For git repos we interpret #... as a tag or branch or commit hash + string(REGEX REPLACE "#([^#]+)$" ";GIT_TAG;\\1" out "${out}") + elseif(packageType STREQUAL "archive") + # For archives we interpret #... as a URL hash. + string(REGEX REPLACE "#([^#]+)$" ";URL_HASH;\\1" out "${out}") + # We don't try to parse the version if it's not provided explicitly. cpm_get_version_from_url + # should do this at a later point + else() + # We should never get here. This is an assertion and hitting it means there's a bug in the code + # above. A packageType was set, but not handled by this if-else. + message(FATAL_ERROR "${CPM_INDENT} Unsupported package type '${packageType}' of '${arg}'") + endif() + + set(${outArgs} + ${out} + PARENT_SCOPE + ) +endfunction() + +# Check that the working directory for a git repo is clean +function(cpm_check_git_working_dir_is_clean repoPath gitTag isClean) + + find_package(Git REQUIRED) + + if(NOT GIT_EXECUTABLE) + # No git executable, assume directory is clean + set(${isClean} + TRUE + PARENT_SCOPE + ) + return() + endif() + + # check for uncommitted changes + execute_process( + COMMAND ${GIT_EXECUTABLE} status --porcelain + RESULT_VARIABLE resultGitStatus + OUTPUT_VARIABLE repoStatus + OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET + WORKING_DIRECTORY ${repoPath} + ) + if(resultGitStatus) + # not supposed to happen, assume clean anyway + message(WARNING "${CPM_INDENT} Calling git status on folder ${repoPath} failed") + set(${isClean} + TRUE + PARENT_SCOPE + ) + return() + endif() + + if(NOT "${repoStatus}" STREQUAL "") + set(${isClean} + FALSE + PARENT_SCOPE + ) + return() + endif() + + # check for committed changes + execute_process( + COMMAND ${GIT_EXECUTABLE} diff -s --exit-code ${gitTag} + RESULT_VARIABLE resultGitDiff + OUTPUT_STRIP_TRAILING_WHITESPACE 
OUTPUT_QUIET + WORKING_DIRECTORY ${repoPath} + ) + + if(${resultGitDiff} EQUAL 0) + set(${isClean} + TRUE + PARENT_SCOPE + ) + else() + set(${isClean} + FALSE + PARENT_SCOPE + ) + endif() + +endfunction() + +# method to overwrite internal FetchContent properties, to allow using CPM.cmake to overload +# FetchContent calls. As these are internal cmake properties, this method should be used carefully +# and may need modification in future CMake versions. Source: +# https://github.com/Kitware/CMake/blob/dc3d0b5a0a7d26d43d6cfeb511e224533b5d188f/Modules/FetchContent.cmake#L1152 +function(cpm_override_fetchcontent contentName) + cmake_parse_arguments(PARSE_ARGV 1 arg "" "SOURCE_DIR;BINARY_DIR" "") + if(NOT "${arg_UNPARSED_ARGUMENTS}" STREQUAL "") + message(FATAL_ERROR "${CPM_INDENT} Unsupported arguments: ${arg_UNPARSED_ARGUMENTS}") + endif() + + string(TOLOWER ${contentName} contentNameLower) + set(prefix "_FetchContent_${contentNameLower}") + + set(propertyName "${prefix}_sourceDir") + define_property( + GLOBAL + PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} "${arg_SOURCE_DIR}") + + set(propertyName "${prefix}_binaryDir") + define_property( + GLOBAL + PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} "${arg_BINARY_DIR}") + + set(propertyName "${prefix}_populated") + define_property( + GLOBAL + PROPERTY ${propertyName} + BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()" + FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}" + ) + set_property(GLOBAL PROPERTY ${propertyName} TRUE) +endfunction() + +# Download and add a package from source +function(CPMAddPackage) + cpm_set_policies() + + 
list(LENGTH ARGN argnLength) + if(argnLength EQUAL 1) + cpm_parse_add_package_single_arg("${ARGN}" ARGN) + + # The shorthand syntax implies EXCLUDE_FROM_ALL + set(ARGN "${ARGN};EXCLUDE_FROM_ALL;YES") + endif() + + set(oneValueArgs + NAME + FORCE + VERSION + GIT_TAG + DOWNLOAD_ONLY + GITHUB_REPOSITORY + GITLAB_REPOSITORY + BITBUCKET_REPOSITORY + GIT_REPOSITORY + SOURCE_DIR + DOWNLOAD_COMMAND + FIND_PACKAGE_ARGUMENTS + NO_CACHE + GIT_SHALLOW + EXCLUDE_FROM_ALL + SOURCE_SUBDIR + ) + + set(multiValueArgs URL OPTIONS) + + cmake_parse_arguments(CPM_ARGS "" "${oneValueArgs}" "${multiValueArgs}" "${ARGN}") + + # Set default values for arguments + + if(NOT DEFINED CPM_ARGS_VERSION) + if(DEFINED CPM_ARGS_GIT_TAG) + cpm_get_version_from_git_tag("${CPM_ARGS_GIT_TAG}" CPM_ARGS_VERSION) + endif() + endif() + + if(CPM_ARGS_DOWNLOAD_ONLY) + set(DOWNLOAD_ONLY ${CPM_ARGS_DOWNLOAD_ONLY}) + else() + set(DOWNLOAD_ONLY NO) + endif() + + if(DEFINED CPM_ARGS_GITHUB_REPOSITORY) + set(CPM_ARGS_GIT_REPOSITORY "https://github.com/${CPM_ARGS_GITHUB_REPOSITORY}.git") + elseif(DEFINED CPM_ARGS_GITLAB_REPOSITORY) + set(CPM_ARGS_GIT_REPOSITORY "https://gitlab.com/${CPM_ARGS_GITLAB_REPOSITORY}.git") + elseif(DEFINED CPM_ARGS_BITBUCKET_REPOSITORY) + set(CPM_ARGS_GIT_REPOSITORY "https://bitbucket.org/${CPM_ARGS_BITBUCKET_REPOSITORY}.git") + endif() + + if(DEFINED CPM_ARGS_GIT_REPOSITORY) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_REPOSITORY ${CPM_ARGS_GIT_REPOSITORY}) + if(NOT DEFINED CPM_ARGS_GIT_TAG) + set(CPM_ARGS_GIT_TAG v${CPM_ARGS_VERSION}) + endif() + + # If a name wasn't provided, try to infer it from the git repo + if(NOT DEFINED CPM_ARGS_NAME) + cpm_package_name_from_git_uri(${CPM_ARGS_GIT_REPOSITORY} CPM_ARGS_NAME) + endif() + endif() + + set(CPM_SKIP_FETCH FALSE) + + if(DEFINED CPM_ARGS_GIT_TAG) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_TAG ${CPM_ARGS_GIT_TAG}) + # If GIT_SHALLOW is explicitly specified, honor the value. 
+ if(DEFINED CPM_ARGS_GIT_SHALLOW) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_SHALLOW ${CPM_ARGS_GIT_SHALLOW}) + endif() + endif() + + if(DEFINED CPM_ARGS_URL) + # If a name or version aren't provided, try to infer them from the URL + list(GET CPM_ARGS_URL 0 firstUrl) + cpm_package_name_and_ver_from_url(${firstUrl} nameFromUrl verFromUrl) + # If we fail to obtain name and version from the first URL, we could try other URLs if any. + # However multiple URLs are expected to be quite rare, so for now we won't bother. + + # If the caller provided their own name and version, they trump the inferred ones. + if(NOT DEFINED CPM_ARGS_NAME) + set(CPM_ARGS_NAME ${nameFromUrl}) + endif() + if(NOT DEFINED CPM_ARGS_VERSION) + set(CPM_ARGS_VERSION ${verFromUrl}) + endif() + + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS URL "${CPM_ARGS_URL}") + endif() + + # Check for required arguments + + if(NOT DEFINED CPM_ARGS_NAME) + message( + FATAL_ERROR + "${CPM_INDENT} 'NAME' was not provided and couldn't be automatically inferred for package added with arguments: '${ARGN}'" + ) + endif() + + # Check if package has been added before + cpm_check_if_package_already_added(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}") + if(CPM_PACKAGE_ALREADY_ADDED) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + # Check for manual overrides + if(NOT CPM_ARGS_FORCE AND NOT "${CPM_${CPM_ARGS_NAME}_SOURCE}" STREQUAL "") + set(PACKAGE_SOURCE ${CPM_${CPM_ARGS_NAME}_SOURCE}) + set(CPM_${CPM_ARGS_NAME}_SOURCE "") + CPMAddPackage( + NAME "${CPM_ARGS_NAME}" + SOURCE_DIR "${PACKAGE_SOURCE}" + EXCLUDE_FROM_ALL "${CPM_ARGS_EXCLUDE_FROM_ALL}" + OPTIONS "${CPM_ARGS_OPTIONS}" + SOURCE_SUBDIR "${CPM_ARGS_SOURCE_SUBDIR}" + DOWNLOAD_ONLY "${DOWNLOAD_ONLY}" + FORCE True + ) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + # Check for available declaration + if(NOT CPM_ARGS_FORCE AND NOT "${CPM_DECLARATION_${CPM_ARGS_NAME}}" STREQUAL "") + set(declaration ${CPM_DECLARATION_${CPM_ARGS_NAME}}) + 
set(CPM_DECLARATION_${CPM_ARGS_NAME} "") + CPMAddPackage(${declaration}) + cpm_export_variables(${CPM_ARGS_NAME}) + # checking again to ensure version and option compatibility + cpm_check_if_package_already_added(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}") + return() + endif() + + if(CPM_USE_LOCAL_PACKAGES OR CPM_LOCAL_PACKAGES_ONLY) + cpm_find_package(${CPM_ARGS_NAME} "${CPM_ARGS_VERSION}" ${CPM_ARGS_FIND_PACKAGE_ARGUMENTS}) + + if(CPM_PACKAGE_FOUND) + cpm_export_variables(${CPM_ARGS_NAME}) + return() + endif() + + if(CPM_LOCAL_PACKAGES_ONLY) + message( + SEND_ERROR + "${CPM_INDENT} ${CPM_ARGS_NAME} not found via find_package(${CPM_ARGS_NAME} ${CPM_ARGS_VERSION})" + ) + endif() + endif() + + CPMRegisterPackage("${CPM_ARGS_NAME}" "${CPM_ARGS_VERSION}") + + if(DEFINED CPM_ARGS_GIT_TAG) + set(PACKAGE_INFO "${CPM_ARGS_GIT_TAG}") + elseif(DEFINED CPM_ARGS_SOURCE_DIR) + set(PACKAGE_INFO "${CPM_ARGS_SOURCE_DIR}") + else() + set(PACKAGE_INFO "${CPM_ARGS_VERSION}") + endif() + + if(DEFINED FETCHCONTENT_BASE_DIR) + # respect user's FETCHCONTENT_BASE_DIR if set + set(CPM_FETCHCONTENT_BASE_DIR ${FETCHCONTENT_BASE_DIR}) + else() + set(CPM_FETCHCONTENT_BASE_DIR ${CMAKE_BINARY_DIR}/_deps) + endif() + + if(DEFINED CPM_ARGS_DOWNLOAD_COMMAND) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS DOWNLOAD_COMMAND ${CPM_ARGS_DOWNLOAD_COMMAND}) + elseif(DEFINED CPM_ARGS_SOURCE_DIR) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS SOURCE_DIR ${CPM_ARGS_SOURCE_DIR}) + if(NOT IS_ABSOLUTE ${CPM_ARGS_SOURCE_DIR}) + # Expand `CPM_ARGS_SOURCE_DIR` relative path. This is important because EXISTS doesn't work + # for relative paths. 
+ get_filename_component( + source_directory ${CPM_ARGS_SOURCE_DIR} REALPATH BASE_DIR ${CMAKE_CURRENT_BINARY_DIR} + ) + else() + set(source_directory ${CPM_ARGS_SOURCE_DIR}) + endif() + if(NOT EXISTS ${source_directory}) + string(TOLOWER ${CPM_ARGS_NAME} lower_case_name) + # remove timestamps so CMake will re-download the dependency + file(REMOVE_RECURSE "${CPM_FETCHCONTENT_BASE_DIR}/${lower_case_name}-subbuild") + endif() + elseif(CPM_SOURCE_CACHE AND NOT CPM_ARGS_NO_CACHE) + string(TOLOWER ${CPM_ARGS_NAME} lower_case_name) + set(origin_parameters ${CPM_ARGS_UNPARSED_ARGUMENTS}) + list(SORT origin_parameters) + if(CPM_USE_NAMED_CACHE_DIRECTORIES) + string(SHA1 origin_hash "${origin_parameters};NEW_CACHE_STRUCTURE_TAG") + set(download_directory ${CPM_SOURCE_CACHE}/${lower_case_name}/${origin_hash}/${CPM_ARGS_NAME}) + else() + string(SHA1 origin_hash "${origin_parameters}") + set(download_directory ${CPM_SOURCE_CACHE}/${lower_case_name}/${origin_hash}) + endif() + # Expand `download_directory` relative path. This is important because EXISTS doesn't work for + # relative paths. 
+ get_filename_component(download_directory ${download_directory} ABSOLUTE) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS SOURCE_DIR ${download_directory}) + if(EXISTS ${download_directory}) + cpm_store_fetch_properties( + ${CPM_ARGS_NAME} "${download_directory}" + "${CPM_FETCHCONTENT_BASE_DIR}/${lower_case_name}-build" + ) + cpm_get_fetch_properties("${CPM_ARGS_NAME}") + + if(DEFINED CPM_ARGS_GIT_TAG AND NOT (PATCH_COMMAND IN_LIST CPM_ARGS_UNPARSED_ARGUMENTS)) + # warn if cache has been changed since checkout + cpm_check_git_working_dir_is_clean(${download_directory} ${CPM_ARGS_GIT_TAG} IS_CLEAN) + if(NOT ${IS_CLEAN}) + message( + WARNING "${CPM_INDENT} Cache for ${CPM_ARGS_NAME} (${download_directory}) is dirty" + ) + endif() + endif() + + cpm_add_subdirectory( + "${CPM_ARGS_NAME}" "${DOWNLOAD_ONLY}" + "${${CPM_ARGS_NAME}_SOURCE_DIR}/${CPM_ARGS_SOURCE_SUBDIR}" "${${CPM_ARGS_NAME}_BINARY_DIR}" + "${CPM_ARGS_EXCLUDE_FROM_ALL}" "${CPM_ARGS_OPTIONS}" + ) + set(PACKAGE_INFO "${PACKAGE_INFO} at ${download_directory}") + + # As the source dir is already cached/populated, we override the call to FetchContent. + set(CPM_SKIP_FETCH TRUE) + cpm_override_fetchcontent( + "${lower_case_name}" SOURCE_DIR "${${CPM_ARGS_NAME}_SOURCE_DIR}/${CPM_ARGS_SOURCE_SUBDIR}" + BINARY_DIR "${${CPM_ARGS_NAME}_BINARY_DIR}" + ) + + else() + # Enable shallow clone when GIT_TAG is not a commit hash. Our guess may not be accurate, but + # it should guarantee no commit hash get mis-detected. 
+ if(NOT DEFINED CPM_ARGS_GIT_SHALLOW) + cpm_is_git_tag_commit_hash("${CPM_ARGS_GIT_TAG}" IS_HASH) + if(NOT ${IS_HASH}) + list(APPEND CPM_ARGS_UNPARSED_ARGUMENTS GIT_SHALLOW TRUE) + endif() + endif() + + # remove timestamps so CMake will re-download the dependency + file(REMOVE_RECURSE ${CPM_FETCHCONTENT_BASE_DIR}/${lower_case_name}-subbuild) + set(PACKAGE_INFO "${PACKAGE_INFO} to ${download_directory}") + endif() + endif() + + cpm_create_module_file(${CPM_ARGS_NAME} "CPMAddPackage(\"${ARGN}\")") + + if(CPM_PACKAGE_LOCK_ENABLED) + if((CPM_ARGS_VERSION AND NOT CPM_ARGS_SOURCE_DIR) OR CPM_INCLUDE_ALL_IN_PACKAGE_LOCK) + cpm_add_to_package_lock(${CPM_ARGS_NAME} "${ARGN}") + elseif(CPM_ARGS_SOURCE_DIR) + cpm_add_comment_to_package_lock(${CPM_ARGS_NAME} "local directory") + else() + cpm_add_comment_to_package_lock(${CPM_ARGS_NAME} "${ARGN}") + endif() + endif() + + cpm_message( + STATUS "${CPM_INDENT} Adding package ${CPM_ARGS_NAME}@${CPM_ARGS_VERSION} (${PACKAGE_INFO})" + ) + + if(NOT CPM_SKIP_FETCH) + cpm_declare_fetch( + "${CPM_ARGS_NAME}" "${CPM_ARGS_VERSION}" "${PACKAGE_INFO}" "${CPM_ARGS_UNPARSED_ARGUMENTS}" + ) + cpm_fetch_package("${CPM_ARGS_NAME}" populated) + if(${populated}) + cpm_add_subdirectory( + "${CPM_ARGS_NAME}" "${DOWNLOAD_ONLY}" + "${${CPM_ARGS_NAME}_SOURCE_DIR}/${CPM_ARGS_SOURCE_SUBDIR}" "${${CPM_ARGS_NAME}_BINARY_DIR}" + "${CPM_ARGS_EXCLUDE_FROM_ALL}" "${CPM_ARGS_OPTIONS}" + ) + endif() + cpm_get_fetch_properties("${CPM_ARGS_NAME}") + endif() + + set(${CPM_ARGS_NAME}_ADDED YES) + cpm_export_variables("${CPM_ARGS_NAME}") +endfunction() + +# Fetch a previously declared package +macro(CPMGetPackage Name) + if(DEFINED "CPM_DECLARATION_${Name}") + CPMAddPackage(NAME ${Name}) + else() + message(SEND_ERROR "${CPM_INDENT} Cannot retrieve package ${Name}: no declaration available") + endif() +endmacro() + +# export variables available to the caller to the parent scope; expects ${CPM_ARGS_NAME} to be set +macro(cpm_export_variables name) +
set(${name}_SOURCE_DIR + "${${name}_SOURCE_DIR}" + PARENT_SCOPE + ) + set(${name}_BINARY_DIR + "${${name}_BINARY_DIR}" + PARENT_SCOPE + ) + set(${name}_ADDED + "${${name}_ADDED}" + PARENT_SCOPE + ) + set(CPM_LAST_PACKAGE_NAME + "${name}" + PARENT_SCOPE + ) +endmacro() + +# declares a package, so that any call to CPMAddPackage for the package name will use these +# arguments instead. Previous declarations will not be overridden. +macro(CPMDeclarePackage Name) + if(NOT DEFINED "CPM_DECLARATION_${Name}") + set("CPM_DECLARATION_${Name}" "${ARGN}") + endif() +endmacro() + +function(cpm_add_to_package_lock Name) + if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + cpm_prettify_package_arguments(PRETTY_ARGN false ${ARGN}) + file(APPEND ${CPM_PACKAGE_LOCK_FILE} "# ${Name}\nCPMDeclarePackage(${Name}\n${PRETTY_ARGN})\n") + endif() +endfunction() + +function(cpm_add_comment_to_package_lock Name) + if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + cpm_prettify_package_arguments(PRETTY_ARGN true ${ARGN}) + file(APPEND ${CPM_PACKAGE_LOCK_FILE} + "# ${Name} (unversioned)\n# CPMDeclarePackage(${Name}\n${PRETTY_ARGN}#)\n" + ) + endif() +endfunction() + +# includes the package lock file if it exists and creates a target `cpm-update-package-lock` to +# update it +macro(CPMUsePackageLock file) + if(NOT CPM_DONT_CREATE_PACKAGE_LOCK) + get_filename_component(CPM_ABSOLUTE_PACKAGE_LOCK_PATH ${file} ABSOLUTE) + if(EXISTS ${CPM_ABSOLUTE_PACKAGE_LOCK_PATH}) + include(${CPM_ABSOLUTE_PACKAGE_LOCK_PATH}) + endif() + if(NOT TARGET cpm-update-package-lock) + add_custom_target( + cpm-update-package-lock COMMAND ${CMAKE_COMMAND} -E copy ${CPM_PACKAGE_LOCK_FILE} + ${CPM_ABSOLUTE_PACKAGE_LOCK_PATH} + ) + endif() + set(CPM_PACKAGE_LOCK_ENABLED true) + endif() +endmacro() + +# registers a package that has been added to CPM +function(CPMRegisterPackage PACKAGE VERSION) + list(APPEND CPM_PACKAGES ${PACKAGE}) + set(CPM_PACKAGES + ${CPM_PACKAGES} + CACHE INTERNAL "" + ) + set("CPM_PACKAGE_${PACKAGE}_VERSION" + ${VERSION} + CACHE 
INTERNAL "" + ) +endfunction() + +# retrieve the current version of the package to ${OUTPUT} +function(CPMGetPackageVersion PACKAGE OUTPUT) + set(${OUTPUT} + "${CPM_PACKAGE_${PACKAGE}_VERSION}" + PARENT_SCOPE + ) +endfunction() + +# declares a package in FetchContent_Declare +function(cpm_declare_fetch PACKAGE VERSION INFO) + if(${CPM_DRY_RUN}) + cpm_message(STATUS "${CPM_INDENT} Package not declared (dry run)") + return() + endif() + + FetchContent_Declare(${PACKAGE} ${ARGN}) +endfunction() + +# returns properties for a package previously defined by cpm_declare_fetch +function(cpm_get_fetch_properties PACKAGE) + if(${CPM_DRY_RUN}) + return() + endif() + + set(${PACKAGE}_SOURCE_DIR + "${CPM_PACKAGE_${PACKAGE}_SOURCE_DIR}" + PARENT_SCOPE + ) + set(${PACKAGE}_BINARY_DIR + "${CPM_PACKAGE_${PACKAGE}_BINARY_DIR}" + PARENT_SCOPE + ) +endfunction() + +function(cpm_store_fetch_properties PACKAGE source_dir binary_dir) + if(${CPM_DRY_RUN}) + return() + endif() + + set(CPM_PACKAGE_${PACKAGE}_SOURCE_DIR + "${source_dir}" + CACHE INTERNAL "" + ) + set(CPM_PACKAGE_${PACKAGE}_BINARY_DIR + "${binary_dir}" + CACHE INTERNAL "" + ) +endfunction() + +# adds a package as a subdirectory if viable, according to provided options +function( + cpm_add_subdirectory + PACKAGE + DOWNLOAD_ONLY + SOURCE_DIR + BINARY_DIR + EXCLUDE + OPTIONS +) + if(NOT DOWNLOAD_ONLY AND EXISTS ${SOURCE_DIR}/CMakeLists.txt) + if(EXCLUDE) + set(addSubdirectoryExtraArgs EXCLUDE_FROM_ALL) + else() + set(addSubdirectoryExtraArgs "") + endif() + if(OPTIONS) + foreach(OPTION ${OPTIONS}) + cpm_parse_option("${OPTION}") + set(${OPTION_KEY} "${OPTION_VALUE}") + endforeach() + endif() + set(CPM_OLD_INDENT "${CPM_INDENT}") + set(CPM_INDENT "${CPM_INDENT} ${PACKAGE}:") + add_subdirectory(${SOURCE_DIR} ${BINARY_DIR} ${addSubdirectoryExtraArgs}) + set(CPM_INDENT "${CPM_OLD_INDENT}") + endif() +endfunction() + +# downloads a previously declared package via FetchContent and exports the variables +# `${PACKAGE}_SOURCE_DIR` and 
`${PACKAGE}_BINARY_DIR` to the parent scope +function(cpm_fetch_package PACKAGE populated) + set(${populated} + FALSE + PARENT_SCOPE + ) + if(${CPM_DRY_RUN}) + cpm_message(STATUS "${CPM_INDENT} Package ${PACKAGE} not fetched (dry run)") + return() + endif() + + FetchContent_GetProperties(${PACKAGE}) + + string(TOLOWER "${PACKAGE}" lower_case_name) + + if(NOT ${lower_case_name}_POPULATED) + FetchContent_Populate(${PACKAGE}) + set(${populated} + TRUE + PARENT_SCOPE + ) + endif() + + cpm_store_fetch_properties( + ${CPM_ARGS_NAME} ${${lower_case_name}_SOURCE_DIR} ${${lower_case_name}_BINARY_DIR} + ) + + set(${PACKAGE}_SOURCE_DIR + ${${lower_case_name}_SOURCE_DIR} + PARENT_SCOPE + ) + set(${PACKAGE}_BINARY_DIR + ${${lower_case_name}_BINARY_DIR} + PARENT_SCOPE + ) +endfunction() + +# splits a package option +function(cpm_parse_option OPTION) + string(REGEX MATCH "^[^ ]+" OPTION_KEY "${OPTION}") + string(LENGTH "${OPTION}" OPTION_LENGTH) + string(LENGTH "${OPTION_KEY}" OPTION_KEY_LENGTH) + if(OPTION_KEY_LENGTH STREQUAL OPTION_LENGTH) + # no value for key provided, assume user wants to set option to "ON" + set(OPTION_VALUE "ON") + else() + math(EXPR OPTION_KEY_LENGTH "${OPTION_KEY_LENGTH}+1") + string(SUBSTRING "${OPTION}" "${OPTION_KEY_LENGTH}" "-1" OPTION_VALUE) + endif() + set(OPTION_KEY + "${OPTION_KEY}" + PARENT_SCOPE + ) + set(OPTION_VALUE + "${OPTION_VALUE}" + PARENT_SCOPE + ) +endfunction() + +# guesses the package version from a git tag +function(cpm_get_version_from_git_tag GIT_TAG RESULT) + string(LENGTH ${GIT_TAG} length) + if(length EQUAL 40) + # GIT_TAG is probably a git hash + set(${RESULT} + 0 + PARENT_SCOPE + ) + else() + string(REGEX MATCH "v?([0123456789.]*).*" _ ${GIT_TAG}) + set(${RESULT} + ${CMAKE_MATCH_1} + PARENT_SCOPE + ) + endif() +endfunction() + +# guesses if the git tag is a commit hash or an actual tag or a branch name. 
+function(cpm_is_git_tag_commit_hash GIT_TAG RESULT) + string(LENGTH "${GIT_TAG}" length) + # full hash has 40 characters, and short hash has at least 7 characters. + if(length LESS 7 OR length GREATER 40) + set(${RESULT} + 0 + PARENT_SCOPE + ) + else() + if(${GIT_TAG} MATCHES "^[a-fA-F0-9]+$") + set(${RESULT} + 1 + PARENT_SCOPE + ) + else() + set(${RESULT} + 0 + PARENT_SCOPE + ) + endif() + endif() +endfunction() + +function(cpm_prettify_package_arguments OUT_VAR IS_IN_COMMENT) + set(oneValueArgs + NAME + FORCE + VERSION + GIT_TAG + DOWNLOAD_ONLY + GITHUB_REPOSITORY + GITLAB_REPOSITORY + GIT_REPOSITORY + SOURCE_DIR + DOWNLOAD_COMMAND + FIND_PACKAGE_ARGUMENTS + NO_CACHE + GIT_SHALLOW + ) + set(multiValueArgs OPTIONS) + cmake_parse_arguments(CPM_ARGS "" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + foreach(oneArgName ${oneValueArgs}) + if(DEFINED CPM_ARGS_${oneArgName}) + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + if(${oneArgName} STREQUAL "SOURCE_DIR") + string(REPLACE ${CMAKE_SOURCE_DIR} "\${CMAKE_SOURCE_DIR}" CPM_ARGS_${oneArgName} + ${CPM_ARGS_${oneArgName}} + ) + endif() + string(APPEND PRETTY_OUT_VAR " ${oneArgName} ${CPM_ARGS_${oneArgName}}\n") + endif() + endforeach() + foreach(multiArgName ${multiValueArgs}) + if(DEFINED CPM_ARGS_${multiArgName}) + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + string(APPEND PRETTY_OUT_VAR " ${multiArgName}\n") + foreach(singleOption ${CPM_ARGS_${multiArgName}}) + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + string(APPEND PRETTY_OUT_VAR " \"${singleOption}\"\n") + endforeach() + endif() + endforeach() + + if(NOT "${CPM_ARGS_UNPARSED_ARGUMENTS}" STREQUAL "") + if(${IS_IN_COMMENT}) + string(APPEND PRETTY_OUT_VAR "#") + endif() + string(APPEND PRETTY_OUT_VAR " ") + foreach(CPM_ARGS_UNPARSED_ARGUMENT ${CPM_ARGS_UNPARSED_ARGUMENTS}) + string(APPEND PRETTY_OUT_VAR " ${CPM_ARGS_UNPARSED_ARGUMENT}") + endforeach() + string(APPEND PRETTY_OUT_VAR "\n") + 
endif()
+
+  set(${OUT_VAR}
+      ${PRETTY_OUT_VAR}
+      PARENT_SCOPE
+  )
+
+endfunction()
diff --git a/include/math_approx/math_approx.hpp b/include/math_approx/math_approx.hpp
new file mode 100644
index 0000000..0ee9cdf
--- /dev/null
+++ b/include/math_approx/math_approx.hpp
@@ -0,0 +1,12 @@
+#pragma once
+
+namespace math_approx
+{
+}
+
+#include "src/basic_math.hpp"
+
+// #include "src/pow_approx.hpp" // disabled: pow_approx.hpp is not added by this commit, so including it breaks every consumer
+
+#include "src/tanh_approx.hpp"
+#include "src/sigmoid_approx.hpp"
diff --git a/include/math_approx/src/basic_math.hpp b/include/math_approx/src/basic_math.hpp
new file mode 100644
index 0000000..8ab9e85
--- /dev/null
+++ b/include/math_approx/src/basic_math.hpp
@@ -0,0 +1,43 @@
+#pragma once
+
+// If MATH_APPROX_XSIMD_TARGET is not defined
+// the user can still use XSIMD by manually including
+// it before including the math_approx header.
+#if MATH_APPROX_XSIMD_TARGET
+#include <xsimd/xsimd.hpp>
+#endif
+
+#if ! defined(XSIMD_HPP)
+#include <cmath>
+#endif
+
+namespace math_approx
+{
+template <typename T>
+T rsqrt (T x)
+{
+    // sqrtss followed by divss...
this seems to measure a bit faster than the rsqrtss plus NR iteration below + return (T) 1 / std::sqrt (x); + + // fast inverse square root (using rsqrtss hardware instruction), plus one Newton-Raphson iteration + // auto r = xsimd::rsqrt (xsimd::broadcast (x)).get (0); + // x *= r; + // x *= r; + // x += -3.0f; + // r *= -0.5f; + // return x * r; +} + +#if defined(XSIMD_HPP) +template +xsimd::batch rsqrt (xsimd::batch x) +{ + auto r = xsimd::rsqrt (x); + x *= r; + x *= r; + x += -3.0f; + r *= -0.5f; + return x * r; +} +#endif +} // namespace math_approx diff --git a/include/math_approx/src/sigmoid_approx.hpp b/include/math_approx/src/sigmoid_approx.hpp new file mode 100644 index 0000000..e29163c --- /dev/null +++ b/include/math_approx/src/sigmoid_approx.hpp @@ -0,0 +1,76 @@ +#pragma once + +#include "basic_math.hpp" +// #include "pow_approx.hpp" + +namespace math_approx +{ +namespace sigmoid_detail +{ + // These polynomial fits were generated from: https://www.wolframcloud.com/obj/chowdsp/Published/sigmoid_approx.nb + + template + T sig_poly_9 (T x) + { + const auto x_sq = x * x; + const auto y_7_9 = (T) 1.50024356624e-6 + (T) 6.92468584642e-9 * x_sq; + const auto y_5_7_9 = (T) 0.000260923534301 + y_7_9 * x_sq; + const auto y_3_5_7_9 = (T) 0.0208320229264 + y_5_7_9 * x_sq; + const auto y_1_3_5_7_9 = (T) 0.5 + y_3_5_7_9 * x_sq; + return x * y_1_3_5_7_9; + } + + template + T sig_poly_7 (T x) + { + const auto x_sq = x * x; + const auto y_5_7 = (T) 0.000255174491559 + (T) 1.90805380557e-6 * x_sq; + const auto y_3_5_7 = (T) 0.0208503675870 + y_5_7 * x_sq; + const auto y_1_3_5_7 = (T) 0.5 + y_3_5_7 * x_sq; + return x * y_1_3_5_7; + } + + template + T sig_poly_5 (T x) + { + const auto x_sq = x * x; + const auto y_3_5 = (T) 0.0206108521251 + (T) 0.000307906311109 * x_sq; + const auto y_1_3_5 = (T) 0.5 + y_3_5 * x_sq; + return x * y_1_3_5; + } + + template + T sig_poly_3 (T x) + { + const auto x_sq = x * x; + const auto y_1_3 = (T) 0.5 + (T) 0.0233402955195 * x_sq; + 
return x * y_1_3;
+    }
+} // namespace sigmoid_detail
+
+template <int order, typename T>
+T sigmoid (T x) // sigmoid approximation: 0.5 * p(x) / sqrt (p(x)^2 + 1) + 0.5, where p is the odd polynomial of the requested order
+{
+    static_assert (order % 2 == 1 && order <= 9 && order >= 3, "Order must be an odd number within [3, 9]");
+
+    T x_poly {};
+    if constexpr (order == 9)
+        x_poly = sigmoid_detail::sig_poly_9 (x);
+    else if constexpr (order == 7)
+        x_poly = sigmoid_detail::sig_poly_7 (x);
+    else if constexpr (order == 5)
+        x_poly = sigmoid_detail::sig_poly_5 (x);
+    else if constexpr (order == 3)
+        x_poly = sigmoid_detail::sig_poly_3 (x);
+
+    return (T) 0.5 * x_poly * rsqrt (x_poly * x_poly + (T) 1) + (T) 0.5;
+}
+
+// So far this has tested slower than the above approx (for equivalent error),
+// but maybe it will be useful for someone!
+// template <int order, typename T>
+// T sigmoid_exp (T x)
+// {
+//     return (T) 1 / ((T) 1 + math_approx::exp (-x));
+// }
+} // namespace math_approx
diff --git a/include/math_approx/src/tanh_approx.hpp b/include/math_approx/src/tanh_approx.hpp
new file mode 100644
index 0000000..d8d1f13
--- /dev/null
+++ b/include/math_approx/src/tanh_approx.hpp
@@ -0,0 +1,81 @@
+#pragma once
+
+#include "basic_math.hpp"
+
+namespace math_approx
+{
+namespace tanh_detail
+{
+    // These polynomial fits were generated from: https://www.wolframcloud.com/obj/chowdsp/Published/tanh_approx.nb
+
+    template <typename T>
+    T tanh_poly_11 (T x)
+    {
+        const auto x_sq = x * x;
+        const auto y_9_11 = (T) 2.63661358122e-6 + (T) 3.33765558362e-8 * x_sq;
+        const auto y_7_9_11 = (T) 0.000199027336899 + y_9_11 * x_sq;
+        const auto y_5_7_9_11 = (T) 0.00833223857843 + y_7_9_11 * x_sq;
+        const auto y_3_5_7_9_11 = (T) 0.166667159320 + y_5_7_9_11 * x_sq;
+        const auto y_1_3_5_7_9_11 = (T) 1 + y_3_5_7_9_11 * x_sq;
+        return x * y_1_3_5_7_9_11;
+    }
+
+    template <typename T>
+    T tanh_poly_9 (T x)
+    {
+        const auto x_sq = x * x;
+        const auto y_7_9 = (T) 0.000192218110330 + (T) 3.54808622170e-6 * x_sq;
+        const auto y_5_7_9 = (T) 0.00834777254865 + y_7_9 * x_sq;
+        const auto y_3_5_7_9 = (T) 0.166658873283 + y_5_7_9 * x_sq;
+        const auto y_1_3_5_7_9 = (T) 1 +
y_3_5_7_9 * x_sq;
+        return x * y_1_3_5_7_9;
+    }
+
+    template <typename T>
+    T tanh_poly_7 (T x)
+    {
+        const auto x_sq = x * x;
+        const auto y_5_7 = (T) 0.00818199927912 + (T) 0.000243153287690 * x_sq;
+        const auto y_3_5_7 = (T) 0.166769941467 + y_5_7 * x_sq;
+        const auto y_1_3_5_7 = (T) 1 + y_3_5_7 * x_sq;
+        return x * y_1_3_5_7;
+    }
+
+    template <typename T>
+    T tanh_poly_5 (T x)
+    {
+        const auto x_sq = x * x;
+        const auto y_3_5 = (T) 0.165326984031 + (T) 0.00970240200826 * x_sq;
+        const auto y_1_3_5 = (T) 1 + y_3_5 * x_sq;
+        return x * y_1_3_5;
+    }
+
+    template <typename T>
+    T tanh_poly_3 (T x)
+    {
+        const auto x_sq = x * x;
+        const auto y_1_3 = (T) 1 + (T) 0.183428244899 * x_sq;
+        return x * y_1_3;
+    }
+} // namespace tanh_detail
+
+template <int order, typename T>
+T tanh (T x) // tanh approximation: p(x) / sqrt (p(x)^2 + 1), where p is the odd polynomial of the requested order
+{
+    static_assert (order % 2 == 1 && order <= 11 && order >= 3, "Order must be an odd number within [3, 11]");
+
+    T x_poly {};
+    if constexpr (order == 11)
+        x_poly = tanh_detail::tanh_poly_11 (x);
+    else if constexpr (order == 9)
+        x_poly = tanh_detail::tanh_poly_9 (x);
+    else if constexpr (order == 7)
+        x_poly = tanh_detail::tanh_poly_7 (x);
+    else if constexpr (order == 5)
+        x_poly = tanh_detail::tanh_poly_5 (x);
+    else if constexpr (order == 3)
+        x_poly = tanh_detail::tanh_poly_3 (x);
+
+    return x_poly * rsqrt (x_poly * x_poly + (T) 1);
+}
+} // namespace math_approx
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
new file mode 100644
index 0000000..0be4d0b
--- /dev/null
+++ b/test/CMakeLists.txt
@@ -0,0 +1,32 @@
+message(STATUS "math_approx -- Configuring tests...")
+
+CPMAddPackage("gh:catchorg/Catch2@3.2.1")
+include(${Catch2_SOURCE_DIR}/extras/Catch.cmake)
+
+function(setup_catch_test target)
+    add_executable(${target})
+    target_sources(${target} PRIVATE src/${target}.cpp)
+    target_include_directories(${target} PRIVATE ${CMAKE_SOURCE_DIR}/tests/test_utils)
+    target_link_libraries(${target}
+        PRIVATE
+            Catch2::Catch2WithMain
+            math_approx
+    )
+    target_compile_options(${target} PRIVATE
+        $<$<CXX_COMPILER_ID:MSVC>:/W4 /WX>
+        $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wall -Wextra -Wpedantic
-Werror> + ) + + add_custom_command(TARGET ${target} + POST_BUILD + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} -E echo "Copying $ to test-binary" + COMMAND ${CMAKE_COMMAND} -E make_directory test-binary + COMMAND ${CMAKE_COMMAND} -E copy "$" test-binary + ) + + catch_discover_tests(${target} TEST_PREFIX ${target}_) +endfunction(setup_catch_test) + +setup_catch_test(tanh_approx_test) +setup_catch_test(sigmoid_approx_test) diff --git a/test/src/sigmoid_approx_test.cpp b/test/src/sigmoid_approx_test.cpp new file mode 100644 index 0000000..f9cafd7 --- /dev/null +++ b/test/src/sigmoid_approx_test.cpp @@ -0,0 +1,48 @@ +#include "test_helpers.hpp" +#include +#include + +#include + +TEST_CASE ("Sigmoid Approx Test") +{ + const auto all_floats = test_helpers::all_32_bit_floats (-20.0f, 20.0f, 1.0e-5f); + const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) + { return 1.0f / (1.0f + std::exp (-x)); }); + + const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound) + { + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + + const auto error = test_helpers::compute_error (y_exact, y_approx); + const auto max_error = test_helpers::abs_max (error); + + // std::cout << max_error << std::endl; + REQUIRE (std::abs (max_error) < err_bound); + }; + + SECTION ("9th-Order") + { + test_approx ([] (auto x) + { return math_approx::sigmoid<9> (x); }, + 6.1e-7f); + } + SECTION ("7th-Order") + { + test_approx ([] (auto x) + { return math_approx::sigmoid<7> (x); }, + 6.8e-6f); + } + SECTION ("5th-Order") + { + test_approx ([] (auto x) + { return math_approx::sigmoid<5> (x); }, + 9.7e-5f); + } + SECTION ("3th-Order") + { + test_approx ([] (auto x) + { return math_approx::sigmoid<3> (x); }, + 1.7e-3f); + } +} diff --git a/test/src/tanh_approx_test.cpp b/test/src/tanh_approx_test.cpp new file mode 100644 index 0000000..bcfdb83 --- /dev/null +++ b/test/src/tanh_approx_test.cpp @@ -0,0 +1,72 @@ +#include 
"test_helpers.hpp" +#include +#include + +#include + +TEST_CASE ("Tanh Approx Test") +{ + const auto all_floats = test_helpers::all_32_bit_floats (-10.0f, 10.0f, 1.0e-5f); + const auto y_exact = test_helpers::compute_all (all_floats, [] (auto x) + { return std::tanh (x); }); + + const auto test_approx = [&all_floats, &y_exact] (auto&& f_approx, float err_bound, float rel_err_bound, uint32_t ulp_bound) + { + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + + const auto error = test_helpers::compute_error (y_exact, y_approx); + const auto rel_error = test_helpers::compute_rel_error (y_exact, y_approx); + const auto ulp_error = test_helpers::compute_ulp_error (y_exact, y_approx); + + const auto max_error = test_helpers::abs_max (error); + const auto max_rel_error = test_helpers::abs_max (rel_error); + const auto max_ulp_error = *std::max_element (ulp_error.begin(), ulp_error.end()); + + std::cout << max_error << ", " << max_rel_error << ", " << max_ulp_error << std::endl; + REQUIRE (std::abs (max_error) < err_bound); + REQUIRE (std::abs (max_rel_error) < rel_err_bound); + if (ulp_bound > 0) + REQUIRE (max_ulp_error < ulp_bound); + }; + + SECTION ("11th-Order") + { + test_approx ([] (auto x) + { return math_approx::tanh<11> (x); }, + 1.9e-7f, + 3.4e-7f, + 5); + } + SECTION ("9th-Order") + { + test_approx ([] (auto x) + { return math_approx::tanh<9> (x); }, + 1.2e-6f, + 1.3e-6f, + 18); + } + SECTION ("7th-Order") + { + test_approx ([] (auto x) + { return math_approx::tanh<7> (x); }, + 1.4e-5f, + 1.5e-5f, + 226); + } + SECTION ("5th-Order") + { + test_approx ([] (auto x) + { return math_approx::tanh<5> (x); }, + 2.2e-4f, + 2.2e-4f, + 0); + } + SECTION ("3th-Order") + { + test_approx ([] (auto x) + { return math_approx::tanh<3> (x); }, + 3.7e-3f, + 3.8e-3f, + 0); + } +} diff --git a/test/src/test_helpers.hpp b/test/src/test_helpers.hpp new file mode 100644 index 0000000..f1453ef --- /dev/null +++ b/test/src/test_helpers.hpp @@ -0,0 +1,124 @@ 
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <span>
+#include <vector>
+
+namespace test_helpers
+{
+inline auto all_32_bit_floats (float begin, float end, float tol = 1.0e-10f) // walks begin -> end one representable float at a time
+{
+    std::vector<float> vec;
+    vec.reserve (1 << 20);
+    begin = vec.emplace_back (begin);
+    while (begin < end)
+    {
+        if (std::abs (begin) < tol) // once inside (-tol, tol): jump to 0, then straight to tol
+        {
+            begin = vec.emplace_back (0.0f);
+            begin = vec.emplace_back (tol);
+        }
+        begin = vec.emplace_back (std::nextafter (begin, end));
+    }
+
+    return vec;
+}
+
+template <typename F>
+auto compute_all (std::span<const float> all_floats, // returns f applied to every input, in order
+                  F&& f)
+{
+    std::vector<float> y;
+    y.resize (all_floats.size());
+    for (size_t i = 0; i < all_floats.size(); ++i)
+        y[i] = f (all_floats[i]);
+
+    return y;
+}
+
+inline std::vector<float> compute_error (std::span<const float> actual, std::span<const float> approx) // signed error: actual - approx
+{
+    std::vector<float> err;
+    err.resize (actual.size());
+    for (size_t i = 0; i < actual.size(); ++i)
+        err[i] = (actual[i] - approx[i]);
+    return err;
+}
+
+inline std::vector<float> compute_rel_error (std::span<const float> actual, std::span<const float> approx) // signed error relative to actual
+{
+    std::vector<float> err;
+    err.resize (actual.size());
+    for (size_t i = 0; i < actual.size(); ++i)
+        err[i] = actual[i] == approx[i] ? 0.0f : (actual[i] - approx[i]) / actual[i]; // exact match => 0; avoids 0/0 = NaN (e.g. at x = 0), which would corrupt the later max search
+    return err;
+}
+
+// mostly borrowed from Catch2
+inline uint32_t f32_ulp_dist (float lhs, float rhs) // NOLINT
+{
+    // We want X == Y to imply 0 ULP distance even if X and Y aren't
+    // bit-equal (-0 and 0), or X - Y != 0 (same sign infinities).
+    if (lhs == rhs)
+        return 0;
+
+    // We need a properly typed positive zero for type inference.
+    static constexpr float positive_zero {};
+
+    // We want to ensure that +/- 0 is always represented as positive zero
+    if (lhs == positive_zero)
+        lhs = positive_zero;
+    if (rhs == positive_zero)
+        rhs = positive_zero;
+
+    // If arguments have different signs, we can handle them by summing
+    // how far are they from 0 each.
+    if (std::signbit (lhs) != std::signbit (rhs))
+    {
+        return f32_ulp_dist (std::abs (lhs), positive_zero)
+               + f32_ulp_dist (std::abs (rhs), positive_zero);
+    }
+
+    // get the bit pattern of 'x'
+    const auto f32_to_bits = [] (float x) -> uint32_t
+    {
+        uint32_t u;
+        memcpy (&u, &x, 4);
+        return u;
+    };
+
+    // When both lhs and rhs are of the same sign, we can just
+    // read the numbers bitwise as integers, and then subtract them
+    // (assuming IEEE).
+    uint32_t lc = f32_to_bits (lhs);
+    uint32_t rc = f32_to_bits (rhs);
+
+    // The ulp distance between two numbers is symmetric, so to avoid
+    // dealing with overflows we want the bigger converted number on the lhs
+    if (lc < rc)
+        std::swap (lc, rc);
+
+    return lc - rc;
+}
+
+inline auto compute_ulp_error (std::span<const float> actual, std::span<const float> approx) // per-element ULP distance between actual and approx
+{
+
+
+    std::vector<uint32_t> err;
+    err.resize (actual.size());
+    for (size_t i = 0; i < actual.size(); ++i)
+        err[i] = f32_ulp_dist (actual[i], approx[i]);
+    return err;
+}
+
+inline float abs_max (std::span<const float> x) // signed element of largest magnitude; 0 for an empty span
+{
+    if (x.empty()) return 0.0f; // minmax_element would return end() here, which must not be dereferenced
+    const auto [min, max] = std::minmax_element (x.begin(), x.end());
+    if (std::abs (*min) > std::abs (*max))
+        return *min;
+    return *max;
+}
+} // namespace test_helpers
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
new file mode 100644
index 0000000..d893c36
--- /dev/null
+++ b/tools/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(plotter)
+add_subdirectory(bench)
diff --git a/tools/bench/CMakeLists.txt b/tools/bench/CMakeLists.txt
new file mode 100644
index 0000000..d8b4916
--- /dev/null
+++ b/tools/bench/CMakeLists.txt
@@ -0,0 +1,12 @@
+CPMAddPackage(
+  NAME benchmark
+  GITHUB_REPOSITORY google/benchmark
+  VERSION 1.5.2
+  OPTIONS "BENCHMARK_ENABLE_TESTING Off"
+)
+
+add_executable(tanh_approx_bench tanh_bench.cpp)
+target_link_libraries(tanh_approx_bench PRIVATE benchmark::benchmark math_approx)
+
+add_executable(sigmoid_approx_bench sigmoid_bench.cpp)
+target_link_libraries(sigmoid_approx_bench PRIVATE benchmark::benchmark math_approx)
diff --git
a/tools/bench/sigmoid_bench.cpp b/tools/bench/sigmoid_bench.cpp
new file mode 100644
index 0000000..5f868f5
--- /dev/null
+++ b/tools/bench/sigmoid_bench.cpp
@@ -0,0 +1,51 @@
+#include <benchmark/benchmark.h>
+#include <math_approx/math_approx.hpp>
+
+static constexpr size_t N = 2000;
+const auto data = [] // N evenly spaced inputs covering [-10, 10)
+{
+    std::vector<float> x;
+    x.resize (N, 0.0f);
+    for (size_t i = 0; i < N; ++i)
+        x[i] = -10.0f + 20.0f * (float) i / (float) N;
+    return x;
+}();
+
+#define SIGMOID_BENCH(name, func) \
+void name (benchmark::State& state) \
+{ \
+for (auto _ : state) \
+{ \
+for (auto& x : data) \
+{ \
+auto y = func (x); \
+benchmark::DoNotOptimize (y); \
+} \
+} \
+} \
+BENCHMARK (name);
+SIGMOID_BENCH (sigmoid_std, [] (auto x) { return 1.0f / (1.0f + std::exp (-x)); })
+SIGMOID_BENCH (sigmoid_approx9, math_approx::sigmoid<9>)
+SIGMOID_BENCH (sigmoid_approx7, math_approx::sigmoid<7>)
+SIGMOID_BENCH (sigmoid_approx5, math_approx::sigmoid<5>)
+
+#define SIGMOID_SIMD_BENCH(name, func) \
+void name (benchmark::State& state) \
+{ \
+for (auto _ : state) \
+{ \
+for (auto& x : data) \
+{ \
+auto y = func (xsimd::broadcast (x)); \
+static_assert (std::is_same_v<xsimd::batch<float>, decltype(y)>); \
+benchmark::DoNotOptimize (y); \
+} \
+} \
+} \
+BENCHMARK (name);
+SIGMOID_SIMD_BENCH (sigmoid_xsimd, [] (auto x) { return 1.0f / (1.0f + xsimd::exp (-x)); })
+SIGMOID_SIMD_BENCH (sigmoid_simd_approx9, math_approx::sigmoid<9>) // was tanh<9>: the sigmoid benchmarks must time sigmoid
+SIGMOID_SIMD_BENCH (sigmoid_simd_approx7, math_approx::sigmoid<7>)
+SIGMOID_SIMD_BENCH (sigmoid_simd_approx5, math_approx::sigmoid<5>)
+
+BENCHMARK_MAIN();
diff --git a/tools/bench/tanh_bench.cpp b/tools/bench/tanh_bench.cpp
new file mode 100644
index 0000000..b9de365
--- /dev/null
+++ b/tools/bench/tanh_bench.cpp
@@ -0,0 +1,53 @@
+#include <benchmark/benchmark.h>
+#include <math_approx/math_approx.hpp>
+
+static constexpr size_t N = 2000;
+const auto data = []
+{
+    std::vector<float> x;
+    x.resize (N, 0.0f);
+    for (size_t i = 0; i < N; ++i)
+        x[i] = -10.0f + 20.0f * (float) i / (float) N;
+    return x;
+}();
+
+#define TANH_BENCH(name, func) \
+void name (benchmark::State& state) \
+{ \
+    for (auto _ : state) \
+ { \ + for (auto& x : data) \ + { \ + auto y = func (x); \ + benchmark::DoNotOptimize (y); \ + } \ + } \ +} \ +BENCHMARK (name); +TANH_BENCH (tanh_std, std::tanh) +TANH_BENCH (tanh_approx11, math_approx::tanh<11>) +TANH_BENCH (tanh_approx9, math_approx::tanh<9>) +TANH_BENCH (tanh_approx7, math_approx::tanh<7>) +TANH_BENCH (tanh_approx5, math_approx::tanh<5>) + +#define TANH_SIMD_BENCH(name, func) \ +void name (benchmark::State& state) \ +{ \ + for (auto _ : state) \ + { \ + for (auto& x : data) \ + { \ + auto y = func (xsimd::broadcast (x)); \ + static_assert (std::is_same_v, decltype(y)>); \ + benchmark::DoNotOptimize (y); \ + } \ + } \ +} \ +BENCHMARK (name); +TANH_SIMD_BENCH (tanh_xsimd, xsimd::tanh) +TANH_SIMD_BENCH (tanh_simd_approx11, math_approx::tanh<11>) +TANH_SIMD_BENCH (tanh_simd_approx9, math_approx::tanh<9>) +TANH_SIMD_BENCH (tanh_simd_approx7, math_approx::tanh<7>) +TANH_SIMD_BENCH (tanh_simd_approx5, math_approx::tanh<5>) + +BENCHMARK_MAIN(); diff --git a/tools/plotter/CMakeLists.txt b/tools/plotter/CMakeLists.txt new file mode 100644 index 0000000..6e6d33d --- /dev/null +++ b/tools/plotter/CMakeLists.txt @@ -0,0 +1,8 @@ +CPMAddPackage( + NAME matplotlib-cpp + GIT_REPOSITORY https://github.com/jatinchowdhury18/matplotlib-cpp + GIT_TAG main +) + +add_executable(math_approx_plotter plotter.cpp) +target_link_libraries(math_approx_plotter PRIVATE matplotlib-cpp math_approx) diff --git a/tools/plotter/plotter.cpp b/tools/plotter/plotter.cpp new file mode 100644 index 0000000..3fab040 --- /dev/null +++ b/tools/plotter/plotter.cpp @@ -0,0 +1,80 @@ +#include +#include +#include +#include + +#include +namespace plt = matplotlibcpp; + +#include "../../test/src/test_helpers.hpp" +#include + +template +void plot_error (std::span all_floats, + std::span y_exact, + F_Approx&& f_approx, + const std::string& name) +{ + const auto y_approx = test_helpers::compute_all (all_floats, f_approx); + const auto error = test_helpers::compute_error (y_exact, y_approx); + 
std::cout << "Max Error: " << test_helpers::abs_max (error) << std::endl;
+    plt::named_plot (name, all_floats, error);
+}
+
+template <typename F_Approx>
+void plot_rel_error (std::span<const float> all_floats, // plots (y_exact - f_approx(x)) / y_exact and prints its extreme value
+                     std::span<const float> y_exact,
+                     F_Approx&& f_approx,
+                     const std::string& name)
+{
+    const auto y_approx = test_helpers::compute_all (all_floats, f_approx);
+    const auto rel_error = test_helpers::compute_rel_error (y_exact, y_approx);
+    std::cout << "Max Relative Error: " << test_helpers::abs_max (rel_error) << std::endl;
+    plt::named_plot (name, all_floats, rel_error);
+}
+
+template <typename F_Approx>
+void plot_ulp_error (std::span<const float> all_floats, // plots the ULP distance between y_exact and f_approx(x) and prints its maximum
+                     std::span<const float> y_exact,
+                     F_Approx&& f_approx,
+                     const std::string& name)
+{
+    const auto y_approx = test_helpers::compute_all (all_floats, f_approx);
+    const auto ulp_error = test_helpers::compute_ulp_error (y_exact, y_approx);
+    std::cout << "Max ULP Error: " << *std::max_element (ulp_error.begin(), ulp_error.end()) << std::endl;
+    plt::named_plot (name, all_floats, std::vector<float> { ulp_error.begin(), ulp_error.end() });
+}
+
+template <typename F>
+void plot_function (std::span<const float> all_floats, // plots f(x) itself rather than an error curve
+                    F&& f,
+                    const std::string& name)
+{
+    const auto y_approx = test_helpers::compute_all (all_floats, f);
+    plt::named_plot (name, all_floats, y_approx);
+}
+
+int main()
+{
+    plt::figure();
+    const auto range = std::make_pair (-10.0f, 10.0f);
+    static constexpr auto tol = 1.0e-2f;
+
+    const auto all_floats = test_helpers::all_32_bit_floats (range.first, range.second, tol);
+    const auto y_exact = test_helpers::compute_all (all_floats, [] (float x)
+                                                    { return 1.0f / (1.0f + std::exp (-x)); });
+
+    plot_error (
+        all_floats,
+        y_exact,
+        [] (float x)
+        { return math_approx::sigmoid<5> (x); }, // sigmoid_exp only exists as commented-out code in sigmoid_approx.hpp
+        "Sigmoid-5");
+
+    plt::legend ({ { "loc", "upper right" } });
+    plt::xlim (range.first, range.second);
+    plt::grid (true);
+    plt::show();
+
+    return 0;
+}