diff --git a/src/indexnotation/analyzers.jl b/src/indexnotation/analyzers.jl index 26bacd2a..ba143074 100644 --- a/src/indexnotation/analyzers.jl +++ b/src/indexnotation/analyzers.jl @@ -118,7 +118,7 @@ function getinputtensorobjects(ex) end append!(list, list2) end - elseif isa(ex, Expr) && ex.head in (:for, :function) + elseif isa(ex, Expr) && ex.head in (:for, :while) append!(list, getinputtensorobjects(ex.args[2])) elseif isa(ex, Expr) && ex.head == :call && ex.args[1] == :scalar append!(list, gettensorobjects(ex.args[2])) @@ -142,7 +142,7 @@ function getoutputtensorobjects(ex) end append!(list, list2) end - elseif isa(ex, Expr) && ex.head in (:for, :function) + elseif isa(ex, Expr) && ex.head in (:for, :while) append!(list, getoutputtensorobjects(ex.args[2])) end return unique!(list) @@ -159,7 +159,7 @@ function getnewtensorobjects(ex) for e in ex.args append!(list, getnewtensorobjects(e)) end - elseif isa(ex, Expr) && ex.head in (:for, :function) + elseif isa(ex, Expr) && ex.head in (:for, :while) append!(list, getnewtensorobjects(ex.args[2])) end return unique!(list) diff --git a/src/indexnotation/parser.jl b/src/indexnotation/parser.jl index 48b54b4a..6c36276a 100644 --- a/src/indexnotation/parser.jl +++ b/src/indexnotation/parser.jl @@ -20,6 +20,9 @@ mutable struct TensorParser end function (parser::TensorParser)(ex) + if ex isa Expr && ex.head == :function + return Expr(:function, ex.args[1], parser(ex.args[2])) + end for p in parser.preprocessors ex = p(ex) end diff --git a/src/indexnotation/preprocessors.jl b/src/indexnotation/preprocessors.jl index 7f6cd502..f2ee0f29 100644 --- a/src/indexnotation/preprocessors.jl +++ b/src/indexnotation/preprocessors.jl @@ -42,6 +42,9 @@ end # replace all tensor objects by a function of that object function replacetensorobjects(f, ex::Expr) + # first try to replace ex completely + ex2 = f(ex, nothing, nothing) + ex2 !== ex && return ex2 if istensor(ex) obj, leftind, rightind = decomposetensor(ex) return Expr(ex.head, f(obj, leftind, rightind), ex.args[2:end]...) @@ -49,7 +52,7 @@ function replacetensorobjects(f, ex::Expr) return Expr(ex.head, (replacetensorobjects(f, e) for e in ex.args)...) end end -replacetensorobjects(f, ex) = ex +replacetensorobjects(f, ex) = f(ex, nothing, nothing) # expandconj: conjugate individual terms or factors instead of a whole expression function expandconj(ex::Expr)