Skip to content

Commit

Permalink
Merge pull request #636 from finch-tensor/kbd-small-finch-logic-changes
Browse files Browse the repository at this point in the history
Kbd small finch logic changes
  • Loading branch information
kylebd99 authored Nov 14, 2024
2 parents bbb7961 + a5d6a6a commit a3e0926
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 6 deletions.
5 changes: 4 additions & 1 deletion src/FinchLogic/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ function LogicNode(kind::LogicNodeKind, args::Vector)
return LogicNode(kind, args[1], Any, LogicNode[])
elseif kind === deferred && length(args) == 2
return LogicNode(kind, args[1], args[2], LogicNode[])
elseif kind === deferred && length(args) == 3
return LogicNode(kind, (args[1], args[3]), args[2], LogicNode[])
else
args = LogicNode_concatenate_args(args)
if (kind === table && length(args) >= 1) ||
Expand All @@ -231,7 +233,8 @@ end
function Base.getproperty(node::LogicNode, sym::Symbol)
if sym === :kind || sym === :val || sym === :type || sym === :children
return Base.getfield(node, sym)
elseif node.kind === deferred && sym === :ex node.val
elseif node.kind === deferred && sym === :ex node.val isa Tuple ? node.val[1] : node.val
elseif node.kind === deferred && sym === :imm node.val[2]
elseif node.kind === field && sym === :name node.val::Symbol
elseif node.kind === alias && sym === :name node.val::Symbol
elseif node.kind === table && sym === :tns node.children[1]
Expand Down
2 changes: 1 addition & 1 deletion src/FinchNotation/instances.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,4 @@ Base.:(==)(a::VariableInstance, b::VariableInstance) = false
Base.:(==)(a::VariableInstance{tag}, b::VariableInstance{tag}) where {tag} = true
function Base.:(==)(a::FinchNodeInstance, b::FinchNodeInstance)
return operation(a) == operation(b) && arguments(a) == arguments(b)
end
end
73 changes: 73 additions & 0 deletions src/FinchNotation/syntax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,76 @@ function display_statement(io, mime, node::Union{FinchNode, FinchNodeInstance},
error("unimplemented")
end
end

finch_unparse_program(ctx, node) = finch_unparse_program(ctx, finch_leaf(node))
function finch_unparse_program(ctx, node::Union{FinchNode, FinchNodeInstance})
if operation(node) === value
node.val
elseif operation(node) === literal
node.val
elseif operation(node) === index
node.name
elseif operation(node) === variable
node.name
elseif operation(node) === cached
finch_unparse_program(ctx, node.arg)
elseif operation(node) === tag
@assert operation(node.var) === variable
node.var.name
elseif operation(node) === virtual
if node.val == dimless
:_
else
ctx(node)
end
elseif operation(node) === access
tns = finch_unparse_program(ctx, node.tns)
idxs = map(x -> finch_unparse_program(ctx, x), node.idxs)
:($tns[$(idxs...)])
elseif operation(node) === call
op = finch_unparse_program(ctx, node.op)
args = map(x -> finch_unparse_program(ctx, x), node.args)
:($op($(args...)))
elseif operation(node) === loop
idx = finch_unparse_program(ctx, node.idx)
ext = finch_unparse_program(ctx, node.ext)
body = finch_unparse_program(ctx, node.body)
:(for $idx = $ext; $body end)
elseif operation(node) === define
lhs = finch_unparse_program(ctx, node.lhs)
rhs = finch_unparse_program(ctx, node.rhs)
body = finch_unparse_program(ctx, node.body)
:(let $lhs = $rhs; $body end)
elseif operation(node) === sieve
cond = finch_unparse_program(ctx, node.cond)
body = finch_unparse_program(ctx, node.body)
:(if $cond; $body end)
elseif operation(node) === assign
lhs = finch_unparse_program(ctx, node.lhs)
op = finch_unparse_program(ctx, node.op)
rhs = finch_unparse_program(ctx, node.rhs)
if haskey(incs, op)
Expr(incs[op], lhs, rhs)
else
:($lhs <<$op>>= $rhs)
end
elseif operation(node) === declare
tns = finch_unparse_program(ctx, node.tns)
init = finch_unparse_program(ctx, node.init)
:($tns .= $init)
elseif operation(node) === freeze
tns = finch_unparse_program(ctx, node.tns)
:(@freeze($tns))
elseif operation(node) === thaw
tns = finch_unparse_program(ctx, node.tns)
:(@thaw($tns))
elseif operation(node) === yieldbind
args = map(x -> finch_unparse_program(ctx, x), node.args)
:(return($(args...)))
elseif operation(node) === block
bodies = map(x -> finch_unparse_program(ctx, x), node.bodies)
Expr(:block, bodies...)
else
error("unimplemented")
end
end
6 changes: 3 additions & 3 deletions src/scheduler/LogicExecutor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ is given as input to the program.
"""
function defer_tables(ex, node::LogicNode)
if @capture node table(~tns::isimmediate, ~idxs...)
table(deferred(:($ex.tns.val), typeof(tns.val)), map(enumerate(node.idxs)) do (i, idx)
table(deferred(:($ex.tns.val), typeof(tns.val), tns.val), map(enumerate(node.idxs)) do (i, idx)
defer_tables(:($ex.idxs[$i]), idx)
end)
elseif istree(node)
Expand All @@ -29,7 +29,7 @@ function cache_deferred!(ctx, root::LogicNode)
get!(seen, node.val) do
var = freshen(ctx, :V)
push_preamble!(ctx, :($var = $(node.ex)::$(node.type)))
deferred(var, node.type)
deferred(var, node.type, node.imm)
end
end))(root)
end
Expand Down Expand Up @@ -90,4 +90,4 @@ end

function (ctx::LogicExecutorCode)(prgm)
return logic_executor_code(ctx.ctx, prgm)
end
end
15 changes: 14 additions & 1 deletion test/test_interface.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
using Finch: AsArray
using Finch: AsArray, JuliaContext
using Finch.FinchNotation: finch_unparse_program, @finch_program_instance

@testset "interface" begin

@info "Testing Finch Interface"

@testset "finch_unparse" begin
prgm = @finch_program quote
A .= 0
for i = _
A[i] += 1
end
end
@test prgm.val == @finch_program $(finch_unparse_program(JuliaContext(), prgm))
end

#https://github.com/finch-tensor/Finch.jl/issues/383
let
A = [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0]
Expand Down Expand Up @@ -814,4 +825,6 @@ using Finch: AsArray
B = dropfills!(swizzle(A, 2, 1), [0.0 0.0 4.4; 1.1 0.0 0.0; 2.2 0.0 5.5; 3.3 0.0 0.0])
@test B == swizzle(Tensor(Dense{Int64}(SparseList{Int64}(Element{0.0, Float64, Int64}([4.4, 1.1, 2.2, 5.5, 3.3]), 3, [1, 2, 3, 5, 6], [3, 1, 1, 3, 1]), 4)), 2, 1)
end


end

0 comments on commit a3e0926

Please sign in to comment.