Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Fix so that the storage under the value fields is also initialized in the setup
kernel

Former-commit-id: 552cfd980e87d76104ea7c7af4fc2f314bea9951
  • Loading branch information
dumerrill committed Aug 7, 2017
1 parent 8e99759 commit e5dc139
Show file tree
Hide file tree
Showing 3 changed files with 2,173 additions and 25 deletions.
71 changes: 47 additions & 24 deletions cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,13 @@ struct ScanTileState<T, true>


// Device storage
TileDescriptor *d_tile_status;

TxnWord *d_tile_descriptors;

/// Constructor
__host__ __device__ __forceinline__
ScanTileState()
:
d_tile_status(NULL)
d_tile_descriptors(NULL)
{}


Expand All @@ -182,7 +181,7 @@ struct ScanTileState<T, true>
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t /*temp_storage_bytes*/) ///< [in] Size in bytes of \t d_temp_storage allocation
{
d_tile_status = reinterpret_cast<TileDescriptor*>(d_temp_storage);
d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
return cudaSuccess;
}

Expand All @@ -206,16 +205,22 @@ struct ScanTileState<T, true>
__device__ __forceinline__ void InitializeStatus(int num_tiles)
{
int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;

TxnWord val = TxnWord();
TileDescriptor *descriptor = reinterpret_cast<TileDescriptor*>(&val);

if (tile_idx < num_tiles)
{
// Not-yet-set
d_tile_status[TILE_STATUS_PADDING + tile_idx].status = StatusWord(SCAN_TILE_INVALID);
descriptor->status = StatusWord(SCAN_TILE_INVALID);
d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
}

if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
{
// Padding
d_tile_status[threadIdx.x].status = StatusWord(SCAN_TILE_OOB);
descriptor->status = StatusWord(SCAN_TILE_OOB);
d_tile_descriptors[threadIdx.x] = val;
}
}

Expand All @@ -231,7 +236,7 @@ struct ScanTileState<T, true>

TxnWord alias;
*reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
ThreadStore<STORE_CG>(reinterpret_cast<TxnWord*>(d_tile_status + TILE_STATUS_PADDING + tile_idx), alias);
ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
}


Expand All @@ -246,7 +251,7 @@ struct ScanTileState<T, true>

TxnWord alias;
*reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
ThreadStore<STORE_CG>(reinterpret_cast<TxnWord*>(d_tile_status + TILE_STATUS_PADDING + tile_idx), alias);
ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
}

/**
Expand All @@ -257,11 +262,11 @@ struct ScanTileState<T, true>
StatusWord &status,
T &value)
{
TileDescriptor tile_descriptor;
TileDescriptor tile_descriptor;
do
{
__threadfence_block(); // prevent hoisting loads from loop
TxnWord alias = ThreadLoad<LOAD_CG>(reinterpret_cast<TxnWord*>(d_tile_status + TILE_STATUS_PADDING + tile_idx));
TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);

} while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));
Expand Down Expand Up @@ -525,14 +530,14 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>


// Device storage
TileDescriptor *d_tile_status;
TxnWord *d_tile_descriptors;


/// Constructor
__host__ __device__ __forceinline__
ReduceByKeyScanTileState()
:
d_tile_status(NULL)
d_tile_descriptors(NULL)
{}


Expand All @@ -543,7 +548,7 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
void *d_temp_storage, ///< [in] %Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
size_t /*temp_storage_bytes*/) ///< [in] Size in bytes of \t d_temp_storage allocation
{
d_tile_status = reinterpret_cast<TileDescriptor*>(d_temp_storage);
d_tile_descriptors = reinterpret_cast<TxnWord*>(d_temp_storage);
return cudaSuccess;
}

Expand All @@ -566,17 +571,22 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
*/
__device__ __forceinline__ void InitializeStatus(int num_tiles)
{
int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
int tile_idx = (blockIdx.x * blockDim.x) + threadIdx.x;
TxnWord val = TxnWord();
TileDescriptor *descriptor = reinterpret_cast<TileDescriptor*>(&val);

if (tile_idx < num_tiles)
{
// Not-yet-set
d_tile_status[TILE_STATUS_PADDING + tile_idx].status = StatusWord(SCAN_TILE_INVALID);
descriptor->status = StatusWord(SCAN_TILE_INVALID);
d_tile_descriptors[TILE_STATUS_PADDING + tile_idx] = val;
}

if ((blockIdx.x == 0) && (threadIdx.x < TILE_STATUS_PADDING))
{
// Padding
d_tile_status[threadIdx.x].status = StatusWord(SCAN_TILE_OOB);
descriptor->status = StatusWord(SCAN_TILE_OOB);
d_tile_descriptors[threadIdx.x] = val;
}
}

Expand All @@ -593,7 +603,7 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>

TxnWord alias;
*reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
ThreadStore<STORE_CG>(reinterpret_cast<TxnWord*>(d_tile_status + TILE_STATUS_PADDING + tile_idx), alias);
ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
}


Expand All @@ -609,7 +619,7 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>

TxnWord alias;
*reinterpret_cast<TileDescriptor*>(&alias) = tile_descriptor;
ThreadStore<STORE_CG>(reinterpret_cast<TxnWord*>(d_tile_status + TILE_STATUS_PADDING + tile_idx), alias);
ThreadStore<STORE_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx, alias);
}

/**
Expand All @@ -620,16 +630,29 @@ struct ReduceByKeyScanTileState<ValueT, KeyT, true>
StatusWord &status,
KeyValuePairT &value)
{
TxnWord alias = ThreadLoad<LOAD_CG>(reinterpret_cast<TxnWord*>(d_tile_status + TILE_STATUS_PADDING + tile_idx));
TileDescriptor tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
// TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
// TileDescriptor tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
//
// while (tile_descriptor.status == SCAN_TILE_INVALID)
// {
// __threadfence_block(); // prevent hoisting loads from loop
//
// alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
// tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
// }
//
// status = tile_descriptor.status;
// value.value = tile_descriptor.value;
// value.key = tile_descriptor.key;

while (tile_descriptor.status == SCAN_TILE_INVALID)
TileDescriptor tile_descriptor;
do
{
__threadfence_block(); // prevent hoisting loads from loop

alias = ThreadLoad<LOAD_CG>(reinterpret_cast<TxnWord*>(d_tile_status + TILE_STATUS_PADDING + tile_idx));
TxnWord alias = ThreadLoad<LOAD_CG>(d_tile_descriptors + TILE_STATUS_PADDING + tile_idx);
tile_descriptor = reinterpret_cast<TileDescriptor&>(alias);
}

} while (WARP_ANY((tile_descriptor.status == SCAN_TILE_INVALID), 0xffffffff));

status = tile_descriptor.status;
value.value = tile_descriptor.value;
Expand Down
Loading

0 comments on commit e5dc139

Please sign in to comment.