From 2b9e788c764d8014dbd11bdf7e4985b5d92a3d9a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 6 Feb 2022 00:03:16 -0500 Subject: [PATCH] Generalize classify arguments function --- src/interface.jl | 3 ++- src/irgen.jl | 7 +++---- src/spirv.jl | 3 ++- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 42f839e3..fda6a49f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -182,7 +182,8 @@ function process_entry!(@nospecialize(job::CompilerJob), mod::LLVM.Module, if job.source.kernel # pass all bitstypes by value; by default Julia passes aggregates by reference # (this improves performance, and is mandated by certain back-ends like SPIR-V). - args = classify_arguments(job, eltype(llvmtype(entry))) + source_sig = Base.signature_type(job.source.f, job.source.tt)::Type + args = classify_arguments(source_sig, eltype(llvmtype(entry))) for arg in args if arg.cc == BITS_REF attr = if LLVM.version() >= v"12" diff --git a/src/irgen.jl b/src/irgen.jl index 4d527aa6..65c48e8c 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -287,10 +287,8 @@ end GHOST # not passed end -function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType) - source_sig = Base.signature_type(job.source.f, job.source.tt)::Type +function classify_arguments(source_sig::Type, codegen_ft::LLVM.FunctionType) source_types = [source_sig.parameters...] - codegen_types = parameters(codegen_ft) args = [] @@ -396,7 +394,8 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM. else ft end - args = classify_arguments(job, orig_ft) + source_sig = Base.signature_type(job.source.f, job.source.tt)::Type + args = classify_arguments(source_sig, orig_ft) filter!(args) do arg arg.cc != GHOST end diff --git a/src/spirv.jl b/src/spirv.jl index 632b6c51..1ebcfec4 100644 --- a/src/spirv.jl +++ b/src/spirv.jl @@ -239,7 +239,8 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F else ft end - args = classify_arguments(job, orig_ft) + source_sig = Base.signature_type(job.source.f, job.source.tt)::Type + args = classify_arguments(source_sig, entry_f) filter!(args) do arg arg.cc != GHOST end