Skip to content

Commit

Permalink
Merge pull request #874 from fragcolor-xyz/guus/compose-var-type-check
Browse files Browse the repository at this point in the history
Guus/compose var type check
  • Loading branch information
guusw authored Dec 5, 2023
2 parents b44c7a3 + 38be987 commit bd95479
Show file tree
Hide file tree
Showing 29 changed files with 1,872 additions and 254 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test-linux-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,20 +80,20 @@ jobs:
echo "\n"
echo "Running graphics test scripts"
for i in $(find shards/tests -name 'gfx*.shs');
for i in $(find shards/tests -maxdepth 1 -name 'gfx*.shs');
do
echo "Running $i"
build/shards new "$i"
done
for i in $(find shards/tests -name 'gfx*.shs');
for i in $(find shards/tests -maxdepth 1 -name 'gfx*.shs');
do
echo "Running $i"
build/shards new "$i"
done
echo "\n"
echo "Running graphics test scripts"
for i in $(find shards/tests -name 'gfx*.edn');
for i in $(find shards/tests -maxdepth 1 -name 'gfx*.edn');
do
echo "Running $i"
build/shards "$i"
Expand Down
4 changes: 2 additions & 2 deletions include/shards/shards.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ struct Types {
return *this;
}

operator SHTypesInfo() {
operator SHTypesInfo() const {
if (_types.size() > 0) {
return SHTypesInfo{&_types[0], (uint32_t)_types.size(), 0};
return SHTypesInfo{&const_cast<Types*>(this)->_types[0], (uint32_t)_types.size(), 0};
} else {
return SHTypesInfo{nullptr, 0, 0};
}
Expand Down
26 changes: 26 additions & 0 deletions run_tests
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
SHARDS=build/clang-x86_64-pc-windows-msvc/Debug/shards.exe
OK_TAG_DIR=shards/tests/tag_ok
ERR_TAG_DIR=shards/tests/tag_err
IGNORE_TAG_DIR=shards/tests/tag_ignore
mkdir -p $OK_TAG_DIR
mkdir -p $ERR_TAG_DIR
find shards/tests -maxdepth 1 -name "*.shs" | while read -r test; do
# base=`basename "$test"`
OK_TAG=$OK_TAG_DIR/${test##*/}
ERR_TAG=$ERR_TAG_DIR/${test##*/}
IGNORE_TAG=$IGNORE_TAG_DIR/${test##*/}
if [[ -f $OK_TAG || -f $IGNORE_TAG ]]; then
continue
fi
echo "Running $test"
"$SHARDS" new "$test" 2> $ERR_TAG
RESULT=$?
if [[ $RESULT -eq 0 ]]; then
rm -f $OK_TAG
mv $ERR_TAG $OK_TAG
else
echo "!Test failed"
fi
done

echo "Done"
22 changes: 21 additions & 1 deletion shards/core/exposed_type_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ struct RequiredContextVariable {
}
}

void cleanup(SHContext* context = nullptr) {
void cleanup(SHContext *context = nullptr) {
if (variable) {
shards::releaseVariable(variable);
variable = nullptr;
Expand Down Expand Up @@ -176,6 +176,26 @@ inline void getObjectTypes(std::vector<SHTypeInfo> &out, const SHTypeInfo &type)
}
}

inline bool hasContextVariables(const SHTypeInfo &type) {
switch (type.basicType) {
case SHType::ContextVar:
return true;
case SHType::Seq:
for (auto &t : type.seqTypes)
if (hasContextVariables(t))
return true;
break;
case SHType::Table:
for (auto &t : type.table.types)
if (hasContextVariables(t))
return true;
break;
default:
break;
}
return false;
}

} // namespace shards

#endif /* A16CC8A4_FBC4_4500_BE1D_F565963C9C16 */
29 changes: 26 additions & 3 deletions shards/core/foundation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "ops_internal.hpp"

#include "spdlog/spdlog.h"
#include "type_matcher.hpp"

#include <algorithm>
#include <atomic>
Expand Down Expand Up @@ -180,7 +181,8 @@ template <typename T, typename Resource = bumping_memory_resource<sizeof(T) * 32
};

void freeDerivedInfo(SHTypeInfo info);
SHTypeInfo deriveTypeInfo(const SHVar &value, const SHInstanceData &data, std::vector<SHExposedTypeInfo> *expInfo = nullptr);
SHTypeInfo deriveTypeInfo(const SHVar &value, const SHInstanceData &data, std::vector<SHExposedTypeInfo> *expInfo = nullptr,
bool resolveContextVariables = true);
SHTypeInfo cloneTypeInfo(const SHTypeInfo &other);

uint64_t deriveTypeHash(const SHVar &value);
Expand All @@ -189,8 +191,9 @@ uint64_t deriveTypeHash(const SHTypeInfo &value);
struct TypeInfo {
TypeInfo() {}

TypeInfo(const SHVar &var, const SHInstanceData &data, std::vector<SHExposedTypeInfo> *expInfo = nullptr) {
_info = deriveTypeInfo(var, data, expInfo);
TypeInfo(const SHVar &var, const SHInstanceData &data, std::vector<SHExposedTypeInfo> *expInfo = nullptr,
bool resolveContextVariables = true) {
_info = deriveTypeInfo(var, data, expInfo, resolveContextVariables);
}

TypeInfo(const SHTypeInfo &info) { _info = cloneTypeInfo(info); }
Expand Down Expand Up @@ -1513,6 +1516,7 @@ inline std::optional<SHExposedTypeInfo> findExposedVariable(const SHExposedTypes
// Collects all ContextVar references
inline void collectRequiredVariables(const SHExposedTypesInfo &exposed, ExposedInfo &out, const SHVar &var) {
using namespace std::literals;

switch (var.valueType) {
case SHType::ContextVar: {
auto sv = SHSTRVIEW(var);
Expand All @@ -1538,6 +1542,25 @@ inline void collectRequiredVariables(const SHExposedTypesInfo &exposed, ExposedI
}
}

inline bool collectRequiredVariables(const SHExposedTypesInfo &exposed, ExposedInfo &out, const SHVar &var,
SHTypesInfo validTypes, const char *debugTag) {
SHInstanceData data{.shared = exposed};
std::vector<SHExposedTypeInfo> expInfo;
TypeInfo ti(var, data, &expInfo, false);
for (auto &type : validTypes) {
if (TypeMatcher{.isParameter = true, .relaxEmptyTableCheck = true, .relaxEmptySeqCheck = expInfo.empty(), .checkVarTypes = true}.match(ti, type)) {
for (auto &it : expInfo) {
out.push_back(it);
}
return true;
}
}
auto msg = fmt::format("No matching variable found for parameter {}, was: {}, expected any of {}", debugTag, (SHTypeInfo &)ti,
validTypes);
SHLOG_ERROR("{}", msg);
throw ComposeError(msg);
}

template <typename... TArgs>
inline void collectAllRequiredVariables(const SHExposedTypesInfo &exposed, ExposedInfo &out, TArgs &&...args) {
(collectRequiredVariables(exposed, out, std::forward<TArgs>(args)), ...);
Expand Down
49 changes: 35 additions & 14 deletions shards/core/params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "stddef.h"
#include "foundation.hpp"
#include "self_macro.h"
#include "exposed_type_utils.hpp"
#include <type_traits>
#include <shards/shardwrapper.hpp>

Expand All @@ -27,7 +28,8 @@ struct IterableParam {

void (*setParam)(void *varPtr, SHVar var){};
SHVar (*getParam)(void *varPtr){};
void (*collectRequirements)(const SHExposedTypesInfo &exposed, ExposedInfo &out, void *varPtr){};
void (*collectRequirements)(const shards::IterableParam &param, const SHExposedTypesInfo &exposed, ExposedInfo &out,
void *varPtr){};
void (*warmup)(void *varPtr, SHContext *ctx){};
void (*cleanup)(void *varPtr, SHContext *ctx){};

Expand All @@ -39,12 +41,28 @@ struct IterableParam {

template <typename T>
static IterableParam createWithVarInterface(void *(*resolveParamInShard)(void *), const ParameterInfo *paramInfo) {
IterableParam result{.resolveParamInShard = resolveParamInShard,
.paramInfo = paramInfo,
.setParam = [](void *varPtr, SHVar var) { *((T *)varPtr) = var; },
.getParam = [](void *varPtr) -> SHVar { return *((T *)varPtr); },
.collectRequirements = [](const SHExposedTypesInfo &exposed, ExposedInfo &out,
void *varPtr) { collectRequiredVariables(exposed, out, *((T *)varPtr)); }};

IterableParam result{
.resolveParamInShard = resolveParamInShard,
.paramInfo = paramInfo,
.setParam = [](void *varPtr, SHVar var) { *((T *)varPtr) = var; },
.getParam = [](void *varPtr) -> SHVar { return *((T *)varPtr); },
.collectRequirements =
[](const shards::IterableParam &param, const SHExposedTypesInfo &exposed, ExposedInfo &out, void *varPtr) {
collectRequiredVariables(exposed, out, *((T *)varPtr), SHTypesInfo(param.paramInfo->_types),
param.paramInfo->_name);
}};

bool canPossiblyHaveContextVariables = false;
for (auto &type : paramInfo->_types._types) {
if (hasContextVariables(type)) {
canPossiblyHaveContextVariables = true;
break;
}
}
if (!canPossiblyHaveContextVariables) {
result.collectRequirements = nullptr;
}

if constexpr (has_warmup<T>::value) {
result.warmup = [](void *varPtr, SHContext *ctx) { ((T *)varPtr)->warmup(ctx); };
Expand Down Expand Up @@ -111,13 +129,16 @@ struct IterableParam {
// PARAM_COMPOSE_REQUIRED_VARIABLES(data);
// return outputTypes().elements[0];
// }
#define PARAM_COMPOSE_REQUIRED_VARIABLES(__data) \
{ \
size_t numParams; \
const shards::IterableParam *params = getIterableParams(numParams); \
_requiredVariables.clear(); \
for (size_t i = 0; i < numParams; i++) \
params[i].collectRequirements(__data.shared, _requiredVariables, params[i].resolveParamInShard(this)); \
#define PARAM_COMPOSE_REQUIRED_VARIABLES(__data) \
{ \
size_t numParams; \
const shards::IterableParam *params = getIterableParams(numParams); \
_requiredVariables.clear(); \
for (size_t i = 0; i < numParams; i++) { \
if (params[i].collectRequirements) { \
params[i].collectRequirements(params[i], __data.shared, _requiredVariables, params[i].resolveParamInShard(this)); \
} \
} \
}

// Implements setParam()/getParam()
Expand Down
Loading

0 comments on commit bd95479

Please sign in to comment.