Skip to content

Commit

Permalink
vsmigx/vs_migraphx.cpp: add missing input format check
Browse files Browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Oct 6, 2024
1 parent 935728d commit 07158f3
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions vsmigx/vs_migraphx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,43 @@ int numPlanes(
return num_planes;
}

// 0: integer, 1: float, -1: unknown
static inline
int getSampleType(migraphx_shape_datatype_t type) noexcept {
switch (type) {
case migraphx_shape_uint8_type:
case migraphx_shape_uint16_type:
case migraphx_shape_uint32_type:
case migraphx_shape_uint64_type:
return 0;
case migraphx_shape_half_type:
case migraphx_shape_float_type:
case migraphx_shape_double_type:
return 1;
default:
return -1;
}
}

static inline
int getBytesPerSample(migraphx_shape_datatype_t type) noexcept {
switch (type) {
case migraphx_shape_uint8_type:
return 1;
case migraphx_shape_half_type:
case migraphx_shape_uint16_type:
return 2;
case migraphx_shape_float_type:
case migraphx_shape_uint32_type:
return 4;
case migraphx_shape_double_type:
case migraphx_shape_uint64_type:
return 8;
default:
return 0;
}
}

static inline void VS_CC getDeviceProp(
const VSMap *in, VSMap *out, void *userData,
VSCore *core, const VSAPI *vsapi
Expand Down Expand Up @@ -778,6 +815,12 @@ static void VS_CC vsMIGXCreate(
if (type != migraphx_shape_float_type && type != migraphx_shape_half_type) {
return set_error("input type must be float or half");
}
if (in_vis[0]->format->sampleType != getSampleType(type)) {
return set_error("sample type mismatch");
}
if (in_vis[0]->format->bytesPerSample != getBytesPerSample(type)) {
return set_error("bytes per sample mismatch");
}
const size_t * lengths;
size_t ndim;
checkError(migraphx_shape_lengths(&lengths, &ndim, input_shape));
Expand Down

0 comments on commit 07158f3

Please sign in to comment.