Skip to content

Commit

Permalink
Merge pull request #26 from JunNishimura/feature/add_macro
Browse files Browse the repository at this point in the history
add macro
  • Loading branch information
JunNishimura authored Jul 15, 2024
2 parents eb8d73b + 0d6559e commit 629581b
Show file tree
Hide file tree
Showing 8 changed files with 596 additions and 7 deletions.
18 changes: 14 additions & 4 deletions ast/modify.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
package ast

import (
"slices"
)

type ModifierFun func(SExpression) SExpression

func Modify(sexp SExpression, modifier ModifierFun, symbolValue string) SExpression {
func Modify(sexp SExpression, modifier ModifierFun, triggers []string) SExpression {
switch st := sexp.(type) {
case *Program:
for i, sexp := range st.Expressions {
st.Expressions[i] = Modify(sexp, modifier, triggers)
}
case *ConsCell:
if car, ok := st.CarField.(*Symbol); ok && car.Value == symbolValue {
if symbol, ok := st.CarField.(*Symbol); ok && slices.Contains(triggers, symbol.Value) {
// return not only the car field but also the cdr field
// since args(cdr field) are needed to modify the AST
return modifier(sexp)
}
st.CarField = Modify(st.CarField, modifier, symbolValue)
st.CdrField = Modify(st.CdrField, modifier, symbolValue)
st.CarField = Modify(st.CarField, modifier, triggers)
st.CdrField = Modify(st.CdrField, modifier, triggers)
}

return sexp
Expand Down
4 changes: 3 additions & 1 deletion evaluator/backquote.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func evalUnquote(sexp ast.SExpression, env *object.Environment) ast.SExpression
evaluated := Eval(cdr.Car(), env)

return convertObjectToSExpression(evaluated)
}, "unquote")
}, []string{"unquote"})
}

func convertObjectToSExpression(obj object.Object) ast.SExpression {
Expand All @@ -48,6 +48,8 @@ func convertObjectToSExpression(obj object.Object) ast.SExpression {
Literal: fmt.Sprintf("%d", obj.Value),
}
return &ast.IntegerLiteral{Token: t, Value: obj.Value}
case *object.Quote:
return obj.SExpression
default:
return nil
}
Expand Down
235 changes: 235 additions & 0 deletions evaluator/macro_expansion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
package evaluator

import (
"github.com/JunNishimura/go-lisp/ast"
"github.com/JunNishimura/go-lisp/object"
)

var macroNames = []string{}

func DefineMacros(program *ast.Program, env *object.Environment) {
definitions := []int{}

for i, exp := range program.Expressions {
if isMacroDefinition(exp) {
addMacro(exp, env)
definitions = append(definitions, i)
}
}

for i := len(definitions) - 1; i >= 0; i-- {
definitionIndex := definitions[i]
program.Expressions = append(program.Expressions[:definitionIndex], program.Expressions[definitionIndex+1:]...)
}
}

func isMacroDefinition(sexp ast.SExpression) bool {
consCell, ok := sexp.(*ast.ConsCell)
if !ok {
return false
}

car, ok := consCell.Car().(*ast.Symbol)
if !ok {
return false
}

return car.Value == "defmacro"
}

func addMacro(sexp ast.SExpression, env *object.Environment) {
macroName, ok := getMacroName(sexp)
if !ok {
return
}

params, ok := getMacroParams(sexp)
if !ok {
return
}

body, ok := getMacroBody(sexp)
if !ok {
return
}

macro := &object.Macro{
Parameters: params,
Body: body,
Env: env,
}

macroNames = append(macroNames, macroName)
env.Set(macroName, macro)
}

func getMacroName(sexp ast.SExpression) (string, bool) {
consCell, ok := sexp.(*ast.ConsCell)
if !ok {
return "", false
}

consCell, ok = consCell.Cdr().(*ast.ConsCell)
if !ok {
return "", false
}

symbol, ok := consCell.Car().(*ast.Symbol)
if !ok {
return "", false
}

return symbol.Value, true
}

func getMacroParams(sexp ast.SExpression) ([]*ast.Symbol, bool) {
consCell, ok := sexp.(*ast.ConsCell)
if !ok {
return nil, false
}

consCell, ok = consCell.Cdr().(*ast.ConsCell)
if !ok {
return nil, false
}

consCell, ok = consCell.Cdr().(*ast.ConsCell)
if !ok {
return nil, false
}

paramConsCell, ok := consCell.Car().(*ast.ConsCell)
if !ok {
if _, ok := consCell.Car().(*ast.Nil); ok {
// No parameters
return []*ast.Symbol{}, true
}
return nil, false
}

params := []*ast.Symbol{}
for {
symbol, ok := paramConsCell.Car().(*ast.Symbol)
if !ok {
return nil, false
}

params = append(params, symbol)

if _, ok := paramConsCell.Cdr().(*ast.Nil); ok {
break
}

paramConsCell, ok = paramConsCell.Cdr().(*ast.ConsCell)
if !ok {
return nil, false
}
}

return params, true
}

func getMacroBody(sexp ast.SExpression) (ast.SExpression, bool) {
consCell, ok := sexp.(*ast.ConsCell)
if !ok {
return nil, false
}

consCell, ok = consCell.Cdr().(*ast.ConsCell)
if !ok {
return nil, false
}

consCell, ok = consCell.Cdr().(*ast.ConsCell)
if !ok {
return nil, false
}

consCell, ok = consCell.Cdr().(*ast.ConsCell)
if !ok {
return nil, false
}

return consCell.Car(), true
}

func ExpandMacros(program ast.SExpression, env *object.Environment) ast.SExpression {
return ast.Modify(program, func(sexp ast.SExpression) ast.SExpression {
consCell, ok := sexp.(*ast.ConsCell)
if !ok {
return sexp
}

macro, ok := isMacroCall(consCell, env)
if !ok {
return sexp
}

args := quoteArgs(consCell)

evalEnv := extendMacroEnv(macro, args)

evaluated := Eval(macro.Body, evalEnv)

quote, ok := evaluated.(*object.Quote)
if !ok {
panic("we only support returning AST-nodes from macros")
}

return quote.SExpression
}, macroNames)
}

func isMacroCall(consCell *ast.ConsCell, env *object.Environment) (*object.Macro, bool) {
symbol, ok := consCell.Car().(*ast.Symbol)

if !ok {
return nil, false
}

obj, ok := env.Get(symbol.Value)
if !ok {
return nil, false
}

macro, ok := obj.(*object.Macro)
if !ok {
return nil, false
}

return macro, true
}

func quoteArgs(consCell *ast.ConsCell) []*object.Quote {
args := []*object.Quote{}

consCell, ok := consCell.Cdr().(*ast.ConsCell)
if !ok {
return args
}

for {
args = append(args, &object.Quote{SExpression: consCell.Car()})

if _, ok := consCell.Cdr().(*ast.Nil); ok {
break
}

consCell, ok = consCell.Cdr().(*ast.ConsCell)
if !ok {
return args
}
}

return args
}

func extendMacroEnv(macro *object.Macro, args []*object.Quote) *object.Environment {
env := object.NewEnclosedEnvironment(macro.Env)

for i, param := range macro.Parameters {
env.Set(param.Value, args[i])
}

return env
}
93 changes: 93 additions & 0 deletions evaluator/macro_expansion_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package evaluator

import (
"testing"

"github.com/JunNishimura/go-lisp/ast"
"github.com/JunNishimura/go-lisp/lexer"
"github.com/JunNishimura/go-lisp/object"
"github.com/JunNishimura/go-lisp/parser"
)

func TestDefineMacros(t *testing.T) {
input := "(defmacro myMacro (x y) '(+ x y))"

env := object.NewEnvironment()
program := testParseProgram(input)

DefineMacros(program, env)

obj, ok := env.Get("myMacro")
if !ok {
t.Fatalf("macro not in environment")
}

macro, ok := obj.(*object.Macro)
if !ok {
t.Fatalf("object is not Macro. got=%T (%+v)", obj, obj)
}

if len(macro.Parameters) != 2 {
t.Fatalf("wrong number of identifiers. got=%d", len(macro.Parameters))
}

if macro.Parameters[0].String() != "x" {
t.Fatalf("parameter is not 'x'. got=%q", macro.Parameters[0])
}
if macro.Parameters[1].String() != "y" {
t.Fatalf("parameter is not 'y'. got=%q", macro.Parameters[1])
}

expectedBody := "(quote (+ x y))"
if macro.Body.String() != expectedBody {
t.Fatalf("body is not %q. got=%q", expectedBody, macro.Body.String())
}
}

func testParseProgram(input string) *ast.Program {
l := lexer.New(input)
p := parser.New(l)
return p.ParseProgram()
}

func TestExpandMacros(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "expands macro which has no arguments",
input: `
(defmacro hoge () '1)
(hoge)
`,
expected: "1",
},
{
name: "expands macro which has arguments",
input: `
(defmacro hoge (x y) ` + "`" + `(- ,y ,x))
(hoge (+ 2 2) (- 10 5))
`,
expected: "(- (- 10 5) (+ 2 2))",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
expected := testParseProgram(tt.expected)

program := testParseProgram(tt.input)
env := object.NewEnvironment()

DefineMacros(program, env)

expanded := ExpandMacros(program, env)

if expanded.String() != expected.String() {
t.Errorf("not equal. got=%q, want=%q", expanded.String(), expected.String())
}
})
}
}
Loading

0 comments on commit 629581b

Please sign in to comment.