Skip to content

Commit

Permalink
Merge pull request NVIDIA#2 from NVlabs/master
Browse files — browse the repository at this point in the history
Sync fork
  • Loading branch information
dumerrill authored Dec 5, 2017
2 parents a6e0232 + 69c3566 commit edf23e7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 267 deletions.
219 changes: 4 additions & 215 deletions cub/agent/agent_spmv_orig.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ struct AgentSpmv
#if (CUB_PTX_ARCH >= 520)

/*
OffsetT* s_tile_row_end_offsets = &temp_storage.merge_items[tile_num_nonzeros].row_end_offset;
ValueT* s_tile_nonzeros = &temp_storage.merge_items[0].nonzero;
OffsetT* s_tile_row_end_offsets = &temp_storage.aliasable.merge_items[tile_num_nonzeros].row_end_offset;
ValueT* s_tile_nonzeros = &temp_storage.aliasable.merge_items[0].nonzero;
OffsetT col_indices[ITEMS_PER_THREAD];
ValueT mat_values[ITEMS_PER_THREAD];
Expand Down Expand Up @@ -466,8 +466,8 @@ struct AgentSpmv
*/

OffsetT* s_tile_row_end_offsets = &temp_storage.merge_items[0].row_end_offset;
ValueT* s_tile_nonzeros = &temp_storage.merge_items[tile_num_rows + ITEMS_PER_THREAD].nonzero;
OffsetT* s_tile_row_end_offsets = &temp_storage.aliasable.merge_items[0].row_end_offset;
ValueT* s_tile_nonzeros = &temp_storage.aliasable.merge_items[tile_num_rows + ITEMS_PER_THREAD].nonzero;

// Gather the nonzeros for the merge tile into shared memory
#pragma unroll
Expand Down Expand Up @@ -640,217 +640,6 @@ struct AgentSpmv
}







/**
* Consume a merge tile, specialized for indirect load of nonzeros
* /
template <typename IsDirectLoadT>
__device__ __forceinline__ KeyValuePairT ConsumeTile1(
int tile_idx,
CoordinateT tile_start_coord,
CoordinateT tile_end_coord,
IsDirectLoadT is_direct_load) ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch
{
int tile_num_rows = tile_end_coord.x - tile_start_coord.x;
int tile_num_nonzeros = tile_end_coord.y - tile_start_coord.y;
OffsetT* s_tile_row_end_offsets = &temp_storage.merge_items[0].row_end_offset;
int warp_idx = threadIdx.x / WARP_THREADS;
int lane_idx = LaneId();
// Gather the row end-offsets for the merge tile into shared memory
#pragma unroll 1
for (int item = threadIdx.x; item <= tile_num_rows; item += BLOCK_THREADS)
{
s_tile_row_end_offsets[item] = wd_row_end_offsets[tile_start_coord.x + item];
}
CTA_SYNC();
// Search for warp start/end coords
if (lane_idx == 0)
{
MergePathSearch(
OffsetT(warp_idx * ITEMS_PER_WARP), // Diagonal
s_tile_row_end_offsets, // List A
CountingInputIterator<OffsetT>(tile_start_coord.y), // List B
tile_num_rows,
tile_num_nonzeros,
temp_storage.warp_coords[warp_idx]);
CoordinateT last = {tile_num_rows, tile_num_nonzeros};
temp_storage.warp_coords[WARPS] = last;
}
CTA_SYNC();
CoordinateT warp_coord = temp_storage.warp_coords[warp_idx];
CoordinateT warp_end_coord = temp_storage.warp_coords[warp_idx + 1];
OffsetT warp_nonzero_idx = tile_start_coord.y + warp_coord.y;
// Consume whole rows
#pragma unroll 1
for (; warp_coord.x < warp_end_coord.x; ++warp_coord.x)
{
ValueT row_total = 0.0;
OffsetT row_end_offset = s_tile_row_end_offsets[warp_coord.x];
#pragma unroll 1
for (OffsetT nonzero_idx = warp_nonzero_idx + lane_idx;
nonzero_idx < row_end_offset;
nonzero_idx += WARP_THREADS)
{
OffsetT column_idx = wd_column_indices[nonzero_idx];
ValueT value = wd_values[nonzero_idx];
ValueT vector_value = wd_vector_x[column_idx];
row_total += value * vector_value;
}
// Warp reduce
row_total = WarpReduceT(temp_storage.warp_reduce[warp_idx]).Sum(row_total);
// Output
if (lane_idx == 0)
{
spmv_params.d_vector_y[tile_start_coord.x + warp_coord.x] = row_total;
}
warp_nonzero_idx = row_end_offset;
}
// Consume partial portion of thread's last row
if (warp_nonzero_idx < tile_start_coord.y + warp_end_coord.y)
{
ValueT row_total = 0.0;
for (OffsetT nonzero_idx = warp_nonzero_idx + lane_idx;
nonzero_idx < tile_start_coord.y + warp_end_coord.y;
nonzero_idx += WARP_THREADS)
{
OffsetT column_idx = wd_column_indices[nonzero_idx];
ValueT value = wd_values[nonzero_idx];
ValueT vector_value = wd_vector_x[column_idx];
row_total += value * vector_value;
}
// Warp reduce
row_total = WarpReduceT(temp_storage.warp_reduce[warp_idx]).Sum(row_total);
// Output
if (lane_idx == 0)
{
spmv_params.d_vector_y[tile_start_coord.x + warp_coord.x] = row_total;
}
}
// Return the tile's running carry-out
KeyValuePairT tile_carry(tile_num_rows, 0.0);
return tile_carry;
}
*/







/**
* Consume a merge tile, specialized for indirect load of nonzeros
* /
__device__ __forceinline__ KeyValuePairT ConsumeTile2(
int tile_idx,
CoordinateT tile_start_coord,
CoordinateT tile_end_coord,
Int2Type<false> is_direct_load) ///< Marker type indicating whether to load nonzeros directly during path-discovery or beforehand in batch
{
int tile_num_rows = tile_end_coord.x - tile_start_coord.x;
int tile_num_nonzeros = tile_end_coord.y - tile_start_coord.y;
ValueT* s_tile_nonzeros = &temp_storage.merge_items[0].nonzero;
ValueT nonzeros[ITEMS_PER_THREAD];
// Gather the nonzeros for the merge tile into shared memory
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
int nonzero_idx = threadIdx.x + (ITEM * BLOCK_THREADS);
nonzero_idx = CUB_MIN(nonzero_idx, tile_num_nonzeros - 1);
OffsetT column_idx = wd_column_indices[tile_start_coord.y + nonzero_idx];
ValueT value = wd_values[tile_start_coord.y + nonzero_idx];
ValueT vector_value = spmv_params.t_vector_x[column_idx];
#if (CUB_PTX_ARCH >= 350)
vector_value = wd_vector_x[column_idx];
#endif
nonzeros[ITEM] = value * vector_value;
}
// Exchange striped->blocked
BlockExchangeT(temp_storage.exchange).StripedToBlocked(nonzeros);
CTA_SYNC();
// Compute an inclusive prefix sum
BlockPrefixSumT(temp_storage.prefix_sum).InclusiveSum(nonzeros, nonzeros);
CTA_SYNC();
if (threadIdx.x == 0)
s_tile_nonzeros[0] = 0.0;
// Scatter back to smem
#pragma unroll
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD; ++ITEM)
{
int item_idx = (threadIdx.x * ITEMS_PER_THREAD) + ITEM + 1;
s_tile_nonzeros[item_idx] = nonzeros[ITEM];
}
CTA_SYNC();
// Gather the row end-offsets for the merge tile into shared memory
#pragma unroll 1
for (int item = threadIdx.x; item < tile_num_rows; item += BLOCK_THREADS)
{
OffsetT start = CUB_MAX(wd_row_end_offsets[tile_start_coord.x + item - 1], tile_start_coord.y);
OffsetT end = wd_row_end_offsets[tile_start_coord.x + item];
start -= tile_start_coord.y;
end -= tile_start_coord.y;
ValueT row_partial = s_tile_nonzeros[end] - s_tile_nonzeros[start];
spmv_params.d_vector_y[tile_start_coord.x + item] = row_partial;
}
// Get the tile's carry-out
KeyValuePairT tile_carry;
if (threadIdx.x == 0)
{
tile_carry.key = tile_num_rows;
OffsetT start = CUB_MAX(wd_row_end_offsets[tile_end_coord.x - 1], tile_start_coord.y);
start -= tile_start_coord.y;
OffsetT end = tile_num_nonzeros;
tile_carry.value = s_tile_nonzeros[end] - s_tile_nonzeros[start];
}
// Return the tile's running carry-out
return tile_carry;
}
*/


/**
* Consume input tile
*/
Expand Down
88 changes: 36 additions & 52 deletions cub/device/dispatch/dispatch_spmv_orig.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -415,12 +415,41 @@ struct DispatchSpmv
};


/// SM60 (compute capability 6.0+) tuning policies for the SpMV pass and the
/// follow-on segment-fixup pass. Selected by the PTX-version dispatch below
/// when targeting sm_60 or newer.
struct Policy600
{
// SpMV agent policy. The two leading arguments scale the tile shape by
// value size: 64 threads x 5 items for wide (>4-byte, e.g. double) values,
// 128 threads x 7 items for narrow (e.g. float) values — presumably
// BLOCK_THREADS and ITEMS_PER_THREAD; confirm against AgentSpmvPolicy's
// template parameter list. All five load modifiers are LOAD_DEFAULT
// (no LDG/CA caching hints on SM60).
typedef AgentSpmvPolicy<
(sizeof(ValueT) > 4) ? 64 : 128,
(sizeof(ValueT) > 4) ? 5 : 7,
LOAD_DEFAULT,
LOAD_DEFAULT,
LOAD_DEFAULT,
LOAD_DEFAULT,
LOAD_DEFAULT,
false,
BLOCK_SCAN_WARP_SCANS>
SpmvPolicyT;


// Segment-fixup agent policy: 128 threads x 3 items per thread,
// direct (blocked) loads through the read-only cache (LOAD_LDG),
// warp-scan-based block scan.
typedef AgentSegmentFixupPolicy<
128,
3,
BLOCK_LOAD_DIRECT,
LOAD_LDG,
BLOCK_SCAN_WARP_SCANS>
SegmentFixupPolicyT;
};



//---------------------------------------------------------------------
// Tuning policies of current PTX compiler pass
//---------------------------------------------------------------------

#if (CUB_PTX_ARCH >= 500)
#if (CUB_PTX_ARCH >= 600)
typedef Policy600 PtxPolicy;

#elif (CUB_PTX_ARCH >= 500)
typedef Policy500 PtxPolicy;

#elif (CUB_PTX_ARCH >= 370)
Expand Down Expand Up @@ -468,7 +497,12 @@ struct DispatchSpmv
#else

// We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version
if (ptx_version >= 500)
if (ptx_version >= 600)
{
spmv_config.template Init<typename Policy600::SpmvPolicyT>();
segment_fixup_config.template Init<typename Policy600::SegmentFixupPolicyT>();
}
else if (ptx_version >= 500)
{
spmv_config.template Init<typename Policy500::SpmvPolicyT>();
segment_fixup_config.template Init<typename Policy500::SegmentFixupPolicyT>();
Expand Down Expand Up @@ -786,56 +820,6 @@ struct DispatchSpmv
DeviceSegmentFixupKernel<PtxSegmentFixupPolicy, KeyValuePairT*, ValueT*, OffsetT, ScanTileStateT>,
spmv_config, segment_fixup_config))) break;

/*
// Dispatch
if (spmv_params.beta == 0.0)
{
if (spmv_params.alpha == 1.0)
{
// Dispatch y = A*x
if (CubDebug(error = Dispatch(
d_temp_storage, temp_storage_bytes, spmv_params, stream, debug_synchronous,
DeviceSpmv1ColKernel<PtxSpmvPolicyT, ValueT, OffsetT>,
DeviceSpmvSearchKernel<PtxSpmvPolicyT, OffsetT, CoordinateT, SpmvParamsT>,
DeviceSpmvKernel<PtxSpmvPolicyT, ScanTileStateT, ValueT, OffsetT, CoordinateT, false, false>,
DeviceSegmentFixupKernel<PtxSegmentFixupPolicy, KeyValuePairT*, ValueT*, OffsetT, ScanTileStateT>,
spmv_config, segment_fixup_config))) break;
}
else
{
// Dispatch y = alpha*A*x
if (CubDebug(error = Dispatch(
d_temp_storage, temp_storage_bytes, spmv_params, stream, debug_synchronous,
DeviceSpmvSearchKernel<PtxSpmvPolicyT, ScanTileStateT, OffsetT, CoordinateT, SpmvParamsT>,
DeviceSpmvKernel<PtxSpmvPolicyT, ValueT, OffsetT, CoordinateT, true, false>,
DeviceSegmentFixupKernel<PtxSegmentFixupPolicy, KeyValuePairT*, ValueT*, OffsetT, ScanTileStateT>,
spmv_config, segment_fixup_config))) break;
}
}
else
{
if (spmv_params.alpha == 1.0)
{
// Dispatch y = A*x + beta*y
if (CubDebug(error = Dispatch(
d_temp_storage, temp_storage_bytes, spmv_params, stream, debug_synchronous,
DeviceSpmvSearchKernel<PtxSpmvPolicyT, ScanTileStateT, OffsetT, CoordinateT, SpmvParamsT>,
DeviceSpmvKernel<PtxSpmvPolicyT, ValueT, OffsetT, CoordinateT, false, true>,
DeviceSegmentFixupKernel<PtxSegmentFixupPolicy, KeyValuePairT*, ValueT*, OffsetT, ScanTileStateT>,
spmv_config, segment_fixup_config))) break;
}
else
{
// Dispatch y = alpha*A*x + beta*y
if (CubDebug(error = Dispatch(
d_temp_storage, temp_storage_bytes, spmv_params, stream, debug_synchronous,
DeviceSpmvSearchKernel<PtxSpmvPolicyT, ScanTileStateT, OffsetT, CoordinateT, SpmvParamsT>,
DeviceSpmvKernel<PtxSpmvPolicyT, ValueT, OffsetT, CoordinateT, true, true>,
DeviceSegmentFixupKernel<PtxSegmentFixupPolicy, KeyValuePairT*, ValueT*, OffsetT, ScanTileStateT>,
spmv_config, segment_fixup_config))) break;
}
}
*/
}
while (0);

Expand Down

0 comments on commit edf23e7

Please sign in to comment.