Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCLomatic] Refine migration of dim3 with helper function. #2025

Merged
merged 7 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions clang/lib/DPCT/APINamesCooperativeGroups.inc
Original file line number Diff line number Diff line change
Expand Up @@ -970,10 +970,12 @@ MEMBER_CALL_FACTORY_ENTRY("cooperative_groups::__v1::thread_group.num_threads",
MemberExprBase(), false, "get_local_linear_range")
MEMBER_CALL_FACTORY_ENTRY("cooperative_groups::__v1::thread_group.get_type",
MemberExprBase(), false, "get_type")
MEMBER_CALL_FACTORY_ENTRY("cooperative_groups::__v1::thread_block.group_index", MemberExprBase(),
false, "get_group_id")
MEMBER_CALL_FACTORY_ENTRY("cooperative_groups::__v1::thread_block.thread_index", MemberExprBase(),
false, "get_local_id")
CALL_FACTORY_ENTRY("cooperative_groups::__v1::thread_block.group_index",
CALL(MapNames::getDpctNamespace() + "dim3",
MEMBER_CALL(MemberExprBase(), false, "get_group_id")))
CALL_FACTORY_ENTRY("cooperative_groups::__v1::thread_block.thread_index",
CALL(MapNames::getDpctNamespace() + "dim3",
MEMBER_CALL(MemberExprBase(), false, "get_local_id")))

CONDITIONAL_FACTORY_ENTRY(
UseNonUniformGroups,
Expand Down
198 changes: 2 additions & 196 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ void TypeInDeclRule::registerMatcher(MatchFinder &MF) {
MF.addMatcher(
typeLoc(
loc(qualType(hasDeclaration(namedDecl(hasAnyName(
"cudaError", "curandStatus", "cublasStatus", "CUstream",
"dim3", "cudaError", "curandStatus", "cublasStatus", "CUstream",
"CUstream_st", "thrust::complex", "thrust::device_vector",
"thrust::device_ptr", "thrust::device_reference",
"thrust::host_vector", "cublasHandle_t", "CUevent_st", "__half",
Expand Down Expand Up @@ -3002,198 +3002,6 @@ void VectorTypeOperatorRule::runRule(const MatchFinder::MatchResult &Result) {

REGISTER_RULE(VectorTypeOperatorRule, PassKind::PK_Migration)

void ReplaceDim3CtorRule::registerMatcher(MatchFinder &MF) {
// Find dim3 constructors which are part of different casts (representing
// different syntaxes). This includes copy constructors. All constructors
// will be visited once.
MF.addMatcher(cxxConstructExpr(hasType(namedDecl(hasName("dim3"))),
argumentCountIs(1),
unless(hasAncestor(cxxConstructExpr(
hasType(namedDecl(hasName("dim3")))))))
.bind("dim3Top"),
this);

MF.addMatcher(cxxConstructExpr(
hasType(namedDecl(hasName("dim3"))), argumentCountIs(3),
anyOf(hasParent(varDecl()), hasParent(exprWithCleanups())),
unless(hasParent(initListExpr())),
unless(hasAncestor(
cxxConstructExpr(hasType(namedDecl(hasName("dim3")))))))
.bind("dim3CtorDecl"),
this);

MF.addMatcher(
cxxConstructExpr(hasType(namedDecl(hasName("dim3"))), argumentCountIs(3),
// skip fields in a struct. The source loc is
// messed up (points to the start of the struct)
unless(hasParent(initListExpr())),
unless(hasAncestor(cxxRecordDecl())),
unless(hasParent(varDecl())),
unless(hasParent(exprWithCleanups())),
unless(hasAncestor(cxxConstructExpr(
hasType(namedDecl(hasName("dim3")))))))
.bind("dim3CtorNoDecl"),
this);

MF.addMatcher(
typeLoc(loc(qualType(hasDeclaration(anyOf(
namedDecl(hasAnyName("dim3")),
typedefDecl(hasAnyName("dim3")))))))
.bind("dim3Type"),
this);
}

ReplaceDim3Ctor *ReplaceDim3CtorRule::getReplaceDim3Modification(
const MatchFinder::MatchResult &Result) {
if (auto Ctor = getNodeAsType<CXXConstructExpr>(Result, "dim3CtorDecl")) {
if(getParentKernelCall(Ctor))
return nullptr;
// dim3 a; or dim3 a(1);
return new ReplaceDim3Ctor(Ctor, true /*isDecl*/);
} else if (auto Ctor =
getNodeAsType<CXXConstructExpr>(Result, "dim3CtorNoDecl")) {
if(getParentKernelCall(Ctor))
return nullptr;
// deflt = dim3(3);
return new ReplaceDim3Ctor(Ctor, false /*isDecl*/);
} else if (auto Ctor = getNodeAsType<CXXConstructExpr>(Result, "dim3Top")) {
if(getParentKernelCall(Ctor))
return nullptr;
// dim3 d3_6_3 = dim3(ceil(test.x + NUM), NUM + test.y, NUM + test.z + NUM);
if (auto A = ReplaceDim3Ctor::getConstructExpr(Ctor->getArg(0))) {
// strip the top CXXConstructExpr, if there's a CXXConstructExpr further
// down
return new ReplaceDim3Ctor(Ctor, A);
} else {
// Copy constructor case: dim3 a(copyfrom)
// No replacements are needed
return nullptr;
}
}

return nullptr;
}

void ReplaceDim3CtorRule::runRule(const MatchFinder::MatchResult &Result) {
ReplaceDim3Ctor *R = getReplaceDim3Modification(Result);
if (R) {
emplaceTransformation(R);
}

if (auto TL = getNodeAsType<TypeLoc>(Result, "dim3Type")) {
if (TL->getBeginLoc().isInvalid())
return;

auto BeginLoc =
getDefinitionRange(TL->getBeginLoc(), TL->getEndLoc()).getBegin();
SourceManager *SM = Result.SourceManager;

// WA for concatenated macro token
if (SM->isWrittenInScratchSpace(SM->getSpellingLoc(TL->getBeginLoc()))) {
BeginLoc = SM->getExpansionLoc(TL->getBeginLoc());
}

Token Tok;
auto LOpts = Result.Context->getLangOpts();
Lexer::getRawToken(BeginLoc, Tok, *SM, LOpts, true);
if (Tok.isAnyIdentifier()) {
if (TL->getType()->isElaboratedTypeSpecifier()) {
// To handle case like "struct cudaExtent extent;"
auto ETC = TL->getUnqualifiedLoc().getAs<ElaboratedTypeLoc>();
auto NTL = ETC.getNamedTypeLoc();

if (NTL.getTypeLocClass() == clang::TypeLoc::Record) {
auto TSL = NTL.getUnqualifiedLoc().getAs<RecordTypeLoc>();

const std::string TyName =
dpct::DpctGlobalInfo::getTypeName(TSL.getType());
std::string Str =
MapNames::findReplacedName(MapNames::TypeNamesMap, TyName);
insertHeaderForTypeRule(TyName, BeginLoc);
requestHelperFeatureForTypeNames(TyName);

if (!Str.empty()) {
emplaceTransformation(
new ReplaceToken(BeginLoc, TSL.getEndLoc(), std::move(Str)));
return;
}
}
}

std::string TypeName = Tok.getRawIdentifier().str();
std::string Str =
MapNames::findReplacedName(MapNames::TypeNamesMap, TypeName);
insertHeaderForTypeRule(TypeName, BeginLoc);
requestHelperFeatureForTypeNames(TypeName);
if (auto VD = DpctGlobalInfo::findAncestor<VarDecl>(TL)) {
auto TypeStr = VD->getType().getAsString();
if (VD->getKind() == Decl::Var && TypeStr == "dim3") {
std::string Replacement;
std::string ReplacedType = "range";
llvm::raw_string_ostream OS(Replacement);
DpctGlobalInfo::printCtadClass(
OS, buildString(MapNames::getClNamespace(), ReplacedType), 3);
Str = OS.str();
}
}

if (!Str.empty()) {
SrcAPIStaticsMap[TypeName]++;
emplaceTransformation(new ReplaceToken(BeginLoc, std::move(Str)));
return;
}
}
}
}

REGISTER_RULE(ReplaceDim3CtorRule, PassKind::PK_Migration)

// rule for dim3 types member fields replacements.
void Dim3MemberFieldsRule::registerMatcher(MatchFinder &MF) {
// dim3->x/y/z => (*dim3)[0]/[1]/[2]
// dim3.x/y/z => dim3[0]/[1]/[2]
// int64_t{dim3->x/y/z} => int64_t((*dim3)[0]/[1]/[2])
// int64_t{dim3.x/y/z} => int64_t(dim3[0]/[1]/[2])
auto Dim3MemberExpr = [&]() {
return memberExpr(anyOf(
has(implicitCastExpr(hasType(pointsTo(typedefDecl(hasName("dim3")))))),
hasObjectExpression(hasType(qualType(hasCanonicalType(
recordType(hasDeclaration(cxxRecordDecl(hasName("dim3"))))))))));
};
MF.addMatcher(Dim3MemberExpr().bind("Dim3MemberExpr"), this);
MF.addMatcher(
cxxFunctionalCastExpr(
allOf(hasTypeLoc(loc(isSignedInteger())),
hasDescendant(
initListExpr(hasInit(0, ignoringImplicit(Dim3MemberExpr())))
.bind("InitListExpr")))),
this);
}

void Dim3MemberFieldsRule::runRule(const MatchFinder::MatchResult &Result) {
// E.g.
// dim3 *pd3, d3;
// pd3->z; d3.z;
// int64_t{d3.x}, int64_t{pd3->x};
// will migrate to:
// (*pd3)[0]; d3[0];
// sycl::range<3> *pd3, d3;
// int64_t(d3[0]), int64_t((*pd3)[0]);
ExprAnalysis EA;
if (const auto *ILE = getNodeAsType<InitListExpr>(Result, "InitListExpr")) {
EA.analyze(ILE);
} else if (const auto *ME =
getNodeAsType<MemberExpr>(Result, "Dim3MemberExpr")) {
EA.analyze(ME);
} else {
return;
}
emplaceTransformation(EA.getReplacement());
EA.applyAllSubExprRepl();
}

REGISTER_RULE(Dim3MemberFieldsRule, PassKind::PK_Migration)

void DeviceInfoVarRule::registerMatcher(MatchFinder &MF) {
MF.addMatcher(
memberExpr(
Expand Down Expand Up @@ -11888,9 +11696,7 @@ void MathFunctionsRule::registerMatcher(MatchFinder &MF) {
internal::Matcher<NamedDecl>(
new internal::HasNameMatcher(MathFunctionsCallExpr)),
anyOf(unless(hasDeclContext(namespaceDecl(anything()))),
hasDeclContext(namespaceDecl(hasName("std")))))),
unless(hasAncestor(
cxxConstructExpr(hasType(typedefDecl(hasName("dim3")))))))
hasDeclContext(namespaceDecl(hasName("std")))))))
.bind("math"),
this);

Expand Down
16 changes: 0 additions & 16 deletions clang/lib/DPCT/ASTTraversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -580,22 +580,6 @@ class VectorTypeOperatorRule
static const char NamespaceName[];
};

class ReplaceDim3CtorRule : public NamedMigrationRule<ReplaceDim3CtorRule> {
ReplaceDim3Ctor *getReplaceDim3Modification(
const ast_matchers::MatchFinder::MatchResult &Result);

public:
void registerMatcher(ast_matchers::MatchFinder &MF) override;
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
};

/// Migration rule for dim3 types member fields replacements.
class Dim3MemberFieldsRule : public NamedMigrationRule<Dim3MemberFieldsRule> {
public:
void registerMatcher(ast_matchers::MatchFinder &MF) override;
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
};

class CudaExtentRule : public NamedMigrationRule<CudaExtentRule> {
CharSourceRange getConstructorRange(const CXXConstructExpr *Ctor);
void replaceConstructor(const CXXConstructExpr *Ctor);
Expand Down
Loading
Loading