Skip to content

Commit

Permalink
Use a single template for every operator between clad::array_expressi…
Browse files Browse the repository at this point in the history
…on, clad::array, clad::array_ref
  • Loading branch information
PetroZarytskyi committed Nov 14, 2024
1 parent 96eb05d commit fc48ba5
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 320 deletions.
140 changes: 0 additions & 140 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,6 @@ template <typename T> class array {
*this);
}

/// Subtracts the number from every element in the array and returns a new
/// array, when the number is on the left side.
template <typename U, typename std::enable_if<std::is_arithmetic<U>::value,
int>::type = 0>
CUDA_HOST_DEVICE friend array_expression<U, BinarySub, const array<T>&>
operator-(U n, const array<T>& arr) {
return array_expression<U, BinarySub, const array<T>&>(n, arr);
}

/// Implicitly converts from clad::array to pointer to an array of type T
CUDA_HOST_DEVICE operator T*() const { return m_arr; }
}; // class array
Expand All @@ -355,137 +346,6 @@ template <typename T> CUDA_HOST_DEVICE array<T> zero_vector(std::size_t n) {
return array<T>(n);
}

/// Overloaded operators for clad::array which return a new array.

/// Multiplies the number to every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryMul, U>
operator*(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinaryMul, U>(arr, n);
}

/// Multiplies the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryMul, U>
operator*(U n, const array<T>& arr) {
return array_expression<const array<T>&, BinaryMul, U>(arr, n);
}

/// Divides the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryDiv, U>
operator/(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinaryDiv, U>(arr, n);
}

/// Adds the number to every element in the array and returns a new array
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryAdd, U>
operator+(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinaryAdd, U>(arr, n);
}

/// Adds the number to every element in the array and returns an array
/// expression, when the number is on the left side.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryAdd, U>
operator+(U n, const array<T>& arr) {
return array_expression<const array<T>&, BinaryAdd, U>(arr, n);
}

/// Subtracts the number from every element in the array and returns an array
/// expression.
template <typename T, typename U,
typename std::enable_if<std::is_arithmetic<U>::value, int>::type = 0>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinarySub, U>
operator-(const array<T>& arr, U n) {
return array_expression<const array<T>&, BinarySub, U>(arr, n);
}

/// Function to define element wise adding of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryAdd, const array<U>&>
operator+(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryAdd, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise subtraction of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinarySub, const array<U>&>
operator-(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinarySub, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise multiplication of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryMul, const array<U>&>
operator*(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryMul, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise division of two arrays.
template <typename T, typename U>
CUDA_HOST_DEVICE array_expression<const array<T>&, BinaryDiv, const array<U>&>
operator/(const array<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryDiv, const array<U>&>(arr1,
arr2);
}

/// Function to define element wise adding of an arrays and an array_ref.
template <typename T, typename U>
CUDA_HOST_DEVICE
array_expression<const array<T>&, BinaryAdd, const array_ref<U>&>
operator+(const array<T>& arr1, const array_ref<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinaryAdd, const array_ref<U>&>(
arr1, arr2);
}

/// Function to define element wise adding of an arrays_ref and an array.
template <typename T, typename U>
CUDA_HOST_DEVICE
array_expression<const array_ref<T>&, BinaryAdd, const array<U>&>
operator+(const array_ref<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array_ref<T>&, BinaryAdd, const array<U>&>(
arr1, arr2);
}

/// Function to define element wise adding of an arrays and an array_ref.
template <typename T, typename U>
CUDA_HOST_DEVICE
array_expression<const array<T>&, BinarySub, const array_ref<U>&>
operator-(const array<T>& arr1, const array_ref<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array<T>&, BinarySub, const array_ref<U>&>(
arr1, arr2);
}

/// Function to define element wise adding of an arrays_ref and an array.
template <typename T, typename U>
CUDA_HOST_DEVICE
array_expression<const array_ref<T>&, BinarySub, const array<U>&>
operator-(const array_ref<T>& arr1, const array<U>& arr2) {
assert(arr1.size() == arr2.size());
return array_expression<const array_ref<T>&, BinarySub, const array<U>&>(
arr1, arr2);
}

} // namespace clad

#endif // CLAD_ARRAY_H
120 changes: 48 additions & 72 deletions include/clad/Differentiator/ArrayExpression.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,89 +81,65 @@ class array_expression {
}

std::size_t size() const { return std::max(get_size(l), get_size(r)); }
};

// Operator overload for addition.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinaryAdd, RE>
operator+(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryAdd, RE>(
*this, r);
}
// A template class to determine whether a given type is array_expression, array
// or array_ref.
template <typename T> class array;
template <typename T> class array_ref;

// Operator overload for multiplication.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinaryMul, RE>
operator*(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryMul, RE>(
*this, r);
}
template <typename T> struct is_clad_type : std::false_type {};

// Operator overload for subtraction.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinarySub, RE>
operator-(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinarySub, RE>(
*this, r);
}
template <typename LeftExp, typename BinaryOp, typename RightExp>
struct is_clad_type<array_expression<LeftExp, BinaryOp, RightExp>>
: std::true_type {};

// Operator overload for division.
template <typename RE>
array_expression<const array_expression<LeftExp, BinaryOp, RightExp>&,
BinaryDiv, RE>
operator/(const RE& r) const {
return array_expression<
const array_expression<LeftExp, BinaryOp, RightExp>&, BinaryDiv, RE>(
*this, r);
}
};
template <typename T> struct is_clad_type<array<T>> : std::true_type {};

// A class to determine whether a given type is array_expression.
template <typename T> struct is_array_expr : std::false_type {};
template <typename T> struct is_clad_type<array_ref<T>> : std::true_type {};

template <typename LeftExp, typename BinaryOp, typename RightExp>
struct is_array_expr<array_expression<LeftExp, BinaryOp, RightExp>>
: std::true_type {};
// Operator overload for addition, when one of the operands is array_expression,
// array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinaryAdd, const T2&> operator+(const T1& l,
const T2& r) {
return {l, r};
}

// Operator overload for addition, when the right operand is an array_expression
// and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<!is_array_expr<T>::value, int>::type = 0>
array_expression<T, BinaryAdd,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator+(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
return array_expression<T, BinaryAdd,
const array_expression<LeftExp, BinaryOp, RightExp>&>(
l, r);
// Operator overload for multiplication, when one of the operands is
// array_expression, array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinaryMul, const T2&> operator*(const T1& l,
const T2& r) {
return {l, r};
}

// Operator overload for multiplication, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<!is_array_expr<T>::value, int>::type = 0>
array_expression<T, BinaryMul,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator*(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
return array_expression<T, BinaryMul,
const array_expression<LeftExp, BinaryOp, RightExp>&>(
l, r);
// Operator overload for subtraction, when one of the operands is
// array_expression, array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinarySub, const T2&> operator-(const T1& l,
const T2& r) {
return {l, r};
}

// Operator overload for subtraction, when the right operand is an
// array_expression and the left operand is a scalar.
template <typename T, typename LeftExp, typename BinaryOp, typename RightExp,
typename std::enable_if<!is_array_expr<T>::value, int>::type = 0>
array_expression<T, BinarySub,
const array_expression<LeftExp, BinaryOp, RightExp>&>
operator-(const T& l, const array_expression<LeftExp, BinaryOp, RightExp>& r) {
return array_expression<T, BinarySub,
const array_expression<LeftExp, BinaryOp, RightExp>&>(
l, r);
// Operator overload for division, when one of the operands is array_expression,
// array or array_ref.
template <
typename T1, typename T2,
typename std::enable_if<is_clad_type<T1>::value || is_clad_type<T2>::value,
int>::type = 0>
array_expression<const T1&, BinaryDiv, const T2&> operator/(const T1& l,
const T2& r) {
return {l, r};
}
} // namespace clad
// NOLINTEND(*-pointer-arithmetic)
Expand Down
Loading

0 comments on commit fc48ba5

Please sign in to comment.