Skip to content

Commit

Permalink
[RF] Refactor RooBatchCompute to simplify Batches classes
Browse files Browse the repository at this point in the history
The Batches classes in RooBatchCompute have only public data members
anyway. It's not necessary to have extra member functions for changing
them.
  • Loading branch information
guitargeek committed Mar 5, 2024
1 parent fdb6e96 commit 0f96696
Show file tree
Hide file tree
Showing 4 changed files with 390 additions and 406 deletions.
47 changes: 7 additions & 40 deletions roofit/batchcompute/src/Batches.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,65 +23,32 @@ so that they can contain data for every kind of compute function.
#ifndef ROOFIT_BATCHCOMPUTE_BATCHES_H
#define ROOFIT_BATCHCOMPUTE_BATCHES_H

#include <RooBatchComputeTypes.h>

#include <cstdint>

namespace RooBatchCompute {

namespace RF_ARCH {

class Batch {
public:
const double *__restrict _array = nullptr;
bool _isVector = false;

Batch() = default;
inline Batch(InputArr array, bool isVector) : _array{array}, _isVector{isVector} {}

__roodevice__ constexpr bool isItVector() const { return _isVector; }
inline void set(InputArr array, bool isVector)
{
_array = array;
_isVector = isVector;
}
inline void advance(std::size_t _nEvents) { _array += _isVector * _nEvents; }
#ifdef __CUDACC__
__roodevice__ constexpr double operator[](std::size_t i) const noexcept { return _isVector ? _array[i] : _array[0]; }
__device__ constexpr double operator[](std::size_t i) const noexcept { return _isVector ? _array[i] : _array[0]; }
#else
constexpr double operator[](std::size_t i) const noexcept { return _array[i]; }
#endif // #ifdef __CUDACC__
};

/////////////////////////////////////////////////////////////////////////////////////////////////////////

class Batches {
public:
Batch *_arrays = nullptr;
double *_extraArgs = nullptr;
std::size_t _nEvents = 0;
std::size_t _nBatches = 0;
std::size_t _nExtraArgs = 0;
RestrictArr _output = nullptr;

__roodevice__ std::size_t getNEvents() const { return _nEvents; }
__roodevice__ std::size_t getNExtraArgs() const { return _nExtraArgs; }
__roodevice__ double extraArg(std::size_t i) const { return _extraArgs[i]; }
__roodevice__ void setExtraArg(std::size_t i, double val) { _extraArgs[i] = val; }
__roodevice__ Batch operator[](int batchIdx) const { return _arrays[batchIdx]; }
inline void setNEvents(std::size_t n) { _nEvents = n; }
inline void advance(std::size_t nEvents)
{
for (std::size_t i = 0; i < _nBatches; i++)
_arrays[i].advance(nEvents);
_output += nEvents;
}
Batch *args = nullptr;
double *extra;
std::size_t nEvents = 0;
std::size_t nBatches = 0;
std::size_t nExtra = 0;
RestrictArr output = nullptr;
};

// Defines the actual argument type of the compute function.
using BatchesHandle = Batches &;

} // End namespace RF_ARCH
} // end namespace RooBatchCompute

#endif // #ifdef ROOFIT_BATCHCOMPUTE_BATCHES_H
Loading

0 comments on commit 0f96696

Please sign in to comment.