Skip to content

Commit

Permalink
Use a single template for every operator between clad::array_expressi…
Browse files Browse the repository at this point in the history
…on, clad::array, clad::array_ref
  • Loading branch information
PetroZarytskyi committed Nov 14, 2024
1 parent b30a742 commit 91d2395
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 276 deletions.
100 changes: 0 additions & 100 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,6 @@ template <typename T> class array {
*this);
}

/// Subtracts the number from every element in the array and returns a new
/// array, when the number is on the left side.
template <typename U, typename std::enable_if<std::is_arithmetic<U>::value,
int>::type = 0>
CUDA_HOST_DEVICE friend array_expression<U, BinarySub, const array<T>&>
operator-(U n, const array<T>& arr) {
return array_expression<U, BinarySub, const array<T>&>(n, arr);
}

/// Implicitly converts from clad::array to pointer to an array of type T
CUDA_HOST_DEVICE operator T*() const { return m_arr; }
}; // class array
Expand All @@ -355,97 +346,6 @@ template <typename T> CUDA_HOST_DEVICE array<T> zero_vector(std::size_t n) {
return array<T>(n);
}

/// Overloaded operators for clad::array which return a new array.

/// Multiplies the number to every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryMul, U>
operator*(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinaryMul, U>(arr, n);
}

/// Multiplies the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryMul, U>
operator*(U n, const array<T>& arr) {
return array_expression<const array<T>&, BinaryMul, U>(arr, n);
}

/// Divides the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryDiv, U>
operator/(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinaryDiv, U>(arr, n);
}

/// Adds the number to every element in the array and returns a new array
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryAdd, U>
operator+(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinaryAdd, U>(arr, n);
}

/// Adds the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryAdd, U>
operator+(U n, const array<T>& arr) {
return array_expression<const array<T>&, BinaryAdd, U>(arr, n);
}

/// Subtracts the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinarySub, U>
operator-(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinarySub, U>(arr, n);
}

/// Function to define element wise adding of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryAdd, const array<U>&>
operator+(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryAdd, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise subtraction of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinarySub, const array<U>&>
operator-(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinarySub, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise multiplication of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryMul, const array<U>&>
operator*(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryMul, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise division of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryDiv, const array<U>&>
operator/(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryDiv, const array<U>&>(arr1,
arr2);
}

} // namespace clad

#endif // CLAD_ARRAY_H
119 changes: 51 additions & 68 deletions include/clad/Differentiator/ArrayExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,82 +81,65 @@ class array_expression {
}

std::size_t size() const { return std::max(get_size(l), get_size(r)); }
};

// Operator overload for addition.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinaryAdd, RE>
operator+(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryAdd, RE>(
*this, r);
}
// A template class to determine whether a given type is array_expression, array
// or array_ref.
template <typename T> class array;
template <typename T> class array_ref;

// Operator overload for multiplication.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinaryMul, RE>
operator*(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryMul, RE>(
*this, r);
}

// Operator overload for subtraction.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinarySub, RE>
operator-(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinarySub, RE>(
*this, r);
}
template <typename T> struct is_clad_type : std::false_type {};

// Operator overload for division.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinaryDiv, RE>
operator/(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryDiv, RE>(
*this, r);
}
};
template <typename LeftExp, typename BinaryOp, typename RightExp>
struct is_clad_type<array_expression<LeftExp, BinaryOp, RightExp>>
: std::true_type {};

template <typename T> struct is_clad_type<array<T>> : std::true_type {};

template <typename T> struct is_clad_type<array_ref<T>> : std::true_type {};

// Operator overload for addition, when one of the operands is array_expression,
// array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinaryAdd, const T2&> operator+(const T1& l,
const T2& r) {
return {l, r};
}

// Operator overload for addition, when the right operand is an array_expression
// and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinaryAdd,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator+(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
return array_expression<T, BinaryAdd,
const array_expression<LeftExp, BinaryOp, RightExp>&>(
l, r);
// Operator overload for multiplication, when one of the operands is
// array_expression, array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinaryMul, const T2&> operator*(const T1& l,
const T2& r) {
return {l, r};
}

// Operator overload for multiplication, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinaryMul,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator*(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
return array_expression<T, BinaryMul,
const array_expression<LeftExp, BinaryOp, RightExp>&>(
l, r);
// Operator overload for subtraction, when one of the operands is
// array_expression, array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinarySub, const T2&> operator-(const T1& l,
const T2& r) {
return {l, r};
}

// Operator overload for subtraction, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<std::is_arithmetic<T>::value, int>::type = 0>
array_expression<T, BinarySub,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator-(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
return array_expression<T, BinarySub,
const array_expression<LeftExp, BinaryOp, RightExp>&>(
l, r);
// Operator overload for division, when one of the operands is array_expression,
// array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinaryDiv, const T2&> operator/(const T1& l,
const T2& r) {
return {l, r};
}
} // namespace clad
// NOLINTEND(*-pointer-arithmetic)
Expand Down
108 changes: 0 additions & 108 deletions include/clad/Differentiator/ArrayRef.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,114 +216,6 @@ template <typename T> class array_ref {
}
};

/// Overloaded operators for clad::array_ref which returns an array
/// expression.

/// Multiplies the arrays element wise
template <typename T, typename U>
constexpr CUDA_HOST_DEVICE
array_expression<const array_ref<T>&, BinaryMul, const array_ref<U>&>
operator*(const array_ref<T>& Ar, const array_ref<U>& Br) {
assert(Ar.size() == Br.size() &&
"Size of both the array_refs must be equal for carrying out "
"multiplication assignment");
return array_expression<const array_ref<T>&, BinaryMul, const array_ref<U>&>(
Ar, Br);
}

/// Adds the arrays element wise
template <typename T, typename U>
constexpr CUDA_HOST_DEVICE
array_expression<const array_ref<T>&, BinaryAdd, const array_ref<U>&>
operator+(const array_ref<T>& Ar, const array_ref<U>& Br) {
assert(Ar.size() == Br.size() &&
"Size of both the array_refs must be equal for carrying out addition "
"assignment");
return array_expression<const array_ref<T>&, BinaryAdd, const array_ref<U>&>(
Ar, Br);
}

/// Subtracts the arrays element wise
template <typename T, typename U>
constexpr CUDA_HOST_DEVICE
array_expression<const array_ref<T>&, BinarySub, const array_ref<U>&>
operator-(const array_ref<T>& Ar, const array_ref<U>& Br) {
assert(
Ar.size() == Br.size() &&
"Size of both the array_refs must be equal for carrying out subtraction "
"assignment");
return array_expression<const array_ref<T>&, BinarySub, const array_ref<U>&>(
Ar, Br);
}

/// Divides the arrays element wise
template <typename T, typename U>
constexpr CUDA_HOST_DEVICE
array_expression<const array_ref<T>&, BinaryDiv, const array_ref<U>&>
operator/(const array_ref<T>& Ar, const array_ref<U>& Br) {
assert(Ar.size() == Br.size() &&
"Size of both the array_refs must be equal for carrying out division "
"assignment");
return array_expression<const array_ref<T>&, BinaryDiv, const array_ref<U>&>(
Ar, Br);
}

/// Multiplies array_ref by a scalar
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
constexpr CUDA_HOST_DEVICE array_expression<const array_ref<T>&, BinaryMul, U>
operator*(const array_ref<T>& Ar, U a) {
return array_expression<const array_ref<T>&, BinaryMul, U>(Ar, a);
}

/// Multiplies array_ref by a scalar (reverse order)
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
constexpr CUDA_HOST_DEVICE array_expression<const array_ref<T>&, BinaryMul, U>
operator*(U a, const array_ref<T>& Ar) {
return array_expression<const array_ref<T>&, BinaryMul, U>(Ar, a);
}

/// Divides array_ref by a scalar
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
constexpr CUDA_HOST_DEVICE array_expression<const array_ref<T>&, BinaryDiv, U>
operator/(const array_ref<T>& Ar, U a) {
return array_expression<const array_ref<T>&, BinaryDiv, U>(Ar, a);
}

/// Adds array_ref by a scalar
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
constexpr CUDA_HOST_DEVICE array_expression<const array_ref<T>&, BinaryAdd, U>
operator+(const array_ref<T>& Ar, U a) {
return array_expression<const array_ref<T>&, BinaryAdd, U>(Ar, a);
}

/// Adds array_ref by a scalar (reverse order)
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
constexpr CUDA_HOST_DEVICE array_expression<const array_ref<T>&, BinaryAdd, U>
operator+(U a, const array_ref<T>& Ar) {
return array_expression<const array_ref<T>&, BinaryAdd, U>(Ar, a);
}

/// Subtracts array_ref by a scalar
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
constexpr CUDA_HOST_DEVICE array_expression<const array_ref<T>&, BinarySub, U>
operator-(const array_ref<T>& Ar, U a) {
return array_expression<const array_ref<T>&, BinarySub, U>(Ar, a);
}

/// Subtracts array_ref by a scalar (reverse order)
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
constexpr CUDA_HOST_DEVICE array_expression<U, BinarySub, const array_ref<T>&>
operator-(U a, const array_ref<T>& Ar) {
return array_expression<U, BinarySub, const array_ref<T>&>(a, Ar);
}

/// `array_ref<void>` specialisation is created to be used as a placeholder
/// type in the overloaded derived function. All `array_ref<T>` types are
/// implicitly convertible to `array_ref<void>` type.
Expand Down

0 comments on commit 91d2395

Please sign in to comment.