Skip to content

Commit

Permalink
Merge pull request #32 from JunNishimura/feature/add_if_syntax
Browse files Browse the repository at this point in the history
add if syntax
  • Loading branch information
JunNishimura authored Jul 21, 2024
2 parents 838d5d3 + 71eb40e commit 5956d14
Show file tree
Hide file tree
Showing 10 changed files with 752 additions and 6 deletions.
7 changes: 7 additions & 0 deletions ast/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ type Symbol struct {
func (s *Symbol) TokenLiteral() string { return s.Token.Literal }
func (s *Symbol) String() string { return s.Value }

type True struct {
Token token.Token
}

func (t *True) TokenLiteral() string { return t.Token.Literal }
func (t *True) String() string { return "T" }

type SpecialForm struct {
Token token.Token
Value string
Expand Down
136 changes: 136 additions & 0 deletions evaluator/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,140 @@ var builtinFuncs = map[string]*object.Builtin{
return &object.Integer{Value: quotient}
},
},
"=": {
Fn: func(args ...object.Object) object.Object {
if len(args) == 0 {
return newError("wrong number of arguments. got=%d, want=1", len(args))
}

compTo, ok := args[0].(*object.Integer)
if !ok {
return newError("argument to `=` must be INTEGER, got %s", args[0].Type())
}
for _, arg := range args[1:] {
compFrom, ok := arg.(*object.Integer)
if !ok {
return newError("argument to `=` must be INTEGER, got %s", arg.Type())
}
if compFrom.Value != compTo.Value {
return Nil
}
}
return True
},
},
"/=": {
Fn: func(args ...object.Object) object.Object {
if len(args) == 0 {
return newError("wrong number of arguments. got=%d, want=1", len(args))
}

compTo, ok := args[0].(*object.Integer)
if !ok {
return newError("argument to `/=` must be INTEGER, got %s", args[0].Type())
}
for _, arg := range args[1:] {
compFrom, ok := arg.(*object.Integer)
if !ok {
return newError("argument to `/=` must be INTEGER, got %s", arg.Type())
}
if compFrom.Value == compTo.Value {
return Nil
}
}
return True
},
},
"<": {
Fn: func(args ...object.Object) object.Object {
if len(args) == 0 {
return newError("wrong number of arguments. got=%d, want=1", len(args))
}

compTo, ok := args[0].(*object.Integer)
if !ok {
return newError("argument to `<` must be INTEGER, got %s", args[0].Type())
}
for _, arg := range args[1:] {
compFrom, ok := arg.(*object.Integer)
if !ok {
return newError("argument to `<` must be INTEGER, got %s", arg.Type())
}
if compTo.Value >= compFrom.Value {
return Nil
}
compTo = compFrom
}
return True
},
},
"<=": {
Fn: func(args ...object.Object) object.Object {
if len(args) == 0 {
return newError("wrong number of arguments. got=%d, want=1", len(args))
}

compTo, ok := args[0].(*object.Integer)
if !ok {
return newError("argument to `<=` must be INTEGER, got %s", args[0].Type())
}
for _, arg := range args[1:] {
compFrom, ok := arg.(*object.Integer)
if !ok {
return newError("argument to `<=` must be INTEGER, got %s", arg.Type())
}
if compTo.Value > compFrom.Value {
return Nil
}
compTo = compFrom
}
return True
},
},
">": {
Fn: func(args ...object.Object) object.Object {
if len(args) == 0 {
return newError("wrong number of arguments. got=%d, want=1", len(args))
}

compTo, ok := args[0].(*object.Integer)
if !ok {
return newError("argument to `>` must be INTEGER, got %s", args[0].Type())
}
for _, arg := range args[1:] {
compFrom, ok := arg.(*object.Integer)
if !ok {
return newError("argument to `>` must be INTEGER, got %s", arg.Type())
}
if compTo.Value <= compFrom.Value {
return Nil
}
compTo = compFrom
}
return True
},
},
">=": {
Fn: func(args ...object.Object) object.Object {
if len(args) == 0 {
return newError("wrong number of arguments. got=%d, want=1", len(args))
}

compTo, ok := args[0].(*object.Integer)
if !ok {
return newError("argument to `>=` must be INTEGER, got %s", args[0].Type())
}
for _, arg := range args[1:] {
compFrom, ok := arg.(*object.Integer)
if !ok {
return newError("argument to `>=` must be INTEGER, got %s", arg.Type())
}
if compTo.Value < compFrom.Value {
return Nil
}
compTo = compFrom
}
return True
},
},
}
64 changes: 63 additions & 1 deletion evaluator/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (
)

var (
Nil = &object.Nil{}
Nil = &object.Nil{}
True = &object.True{}
)

func Eval(sexp ast.SExpression, env *object.Environment) object.Object {
Expand All @@ -24,6 +25,8 @@ func Eval(sexp ast.SExpression, env *object.Environment) object.Object {
return right
}
return evalPrefixAtom(sexp.Operator, right)
case *ast.True:
return True
case *ast.Nil:
return Nil
case *ast.Symbol:
Expand Down Expand Up @@ -227,6 +230,8 @@ func evalSpecialForm(sexp *ast.ConsCell, env *object.Environment) object.Object
return evalQuote(sexp)
case "backquote":
return evalBackquote(sexp, env)
case "if":
return evalIf(sexp, env)
}

return newError("unknown special form: %s", spForm.Value)
Expand Down Expand Up @@ -366,3 +371,60 @@ func convertObjectToSExpression(obj object.Object) ast.SExpression {
return nil
}
}

func evalIf(consCell *ast.ConsCell, env *object.Environment) object.Object {
spForm, ok := consCell.Car().(*ast.SpecialForm)
if !ok {
return newError("expect special form, got %T", consCell.Car())
}
if spForm.Token.Type != token.IF {
return newError("expect special form if, got %s", spForm.Token.Type)
}

cdr, ok := consCell.Cdr().(*ast.ConsCell)
if !ok {
return newError("not defined if condition")
}

// evaluate the condition
cadr := cdr.Car()
condition := Eval(cadr, env)
if isError(condition) {
return condition
}

cddr, ok := cdr.Cdr().(*ast.ConsCell)
if !ok {
return newError("not defined if consequent")
}

// if condition is true, evaluate the consequent
if isTruthy(condition) {
caddr := cddr.Car()
return Eval(caddr, env)
}

// if alternative is not defined, return nil
cdddr, ok := cddr.Cdr().(*ast.ConsCell)
if !ok {
if _, ok := cddr.Cdr().(*ast.Nil); ok {
return Nil
}
return newError("invalid if alternative")
}

// evaluate the alternative
cadddr := cdddr.Car()
return Eval(cadddr, env)
}

func isTruthy(obj object.Object) bool {
switch obj.(type) {
case *object.True:
return true
case *object.Nil:
return false
default:
return true
}
}
89 changes: 89 additions & 0 deletions evaluator/evaluator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,92 @@ func TestBackQuote(t *testing.T) {
}
}
}

func TestTrueExpression(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"t", "T"},
}

for _, tt := range tests {
evaluated := testEval(tt.input)
if evaluated.Inspect() != tt.expected {
t.Errorf("expected=%q, got=%q", tt.expected, evaluated.Inspect())
}
}
}

func testComparisonObject(t *testing.T, obj object.Object, expected string) {
if expected == "T" {
if obj != True {
t.Errorf("object is not TRUE. got=%T (%+v)", obj, obj)
}
} else {
if obj != Nil {
t.Errorf("object is not NIL. got=%T (%+v)", obj, obj)
}
}
}

func TestComparisonExpression(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"(= 1)", "T"},
{"(= 1 1)", "T"},
{"(= 1 2)", "NIL"},
{"(/= 1)", "T"},
{"(/= 1 1)", "NIL"},
{"(/= 1 2)", "T"},
{"(< 1)", "T"},
{"(< 1 1)", "NIL"},
{"(< 1 2)", "T"},
{"(< 2 1)", "NIL"},
{"(> 1)", "T"},
{"(> 1 1)", "NIL"},
{"(> 1 2)", "NIL"},
{"(> 2 1)", "T"},
{"(<= 1)", "T"},
{"(<= 1 0)", "NIL"},
{"(<= 1 1)", "T"},
{"(<= 1 2)", "T"},
{"(<= 1)", "T"},
{"(<= 2 1)", "NIL"},
{"(<= 2 2)", "T"},
{"(<= 2 3)", "T"},
}

for _, tt := range tests {
evaluated := testEval(tt.input)
testComparisonObject(t, evaluated, tt.expected)
}
}

func TestIfExpression(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"(if t 10)", "10"},
{"(if 0 10)", "10"},
{"(if 1 10)", "10"},
{"(if nil 10)", "nil"},
{"(if t 10 20)", "10"},
{"(if nil 10 20)", "20"},
{"(if (= 1 1) 10 20)", "10"},
{"(if (= 1 2) 10 20)", "20"},
{"(if t (+ 1 1) (+ 2 2))", "2"},
{"(if nil (+ 1 1) (+ 2 2))", "4"},
{"(if (= ((lambda (x y) (+ x y)) 1 1) (* 1 2)) 10 20)", "10"},
}

for _, tt := range tests {
evaluated := testEval(tt.input)
if evaluated.Inspect() != tt.expected {
t.Errorf("expected=%q, got=%q", tt.expected, evaluated.Inspect())
}
}
}
36 changes: 35 additions & 1 deletion lexer/lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,33 @@ func (l *Lexer) NextToken() token.Token {
tok = newToken(token.BACKQUOTE, l.curChar)
case ',':
tok = newToken(token.COMMA, l.curChar)
case '<':
if l.peekChar() == '=' {
lt := l.curChar
l.readChar()
literal := string(lt) + string(l.curChar)
tok = token.Token{Type: token.SYMBOL, Literal: literal}
} else {
tok = newToken(token.SYMBOL, l.curChar)
}
case '>':
if l.peekChar() == '=' {
gt := l.curChar
l.readChar()
literal := string(gt) + string(l.curChar)
tok = token.Token{Type: token.SYMBOL, Literal: literal}
} else {
tok = newToken(token.SYMBOL, l.curChar)
}
case '/':
if l.peekChar() == '=' {
ne := l.curChar
l.readChar()
literal := string(ne) + string(l.curChar)
tok = token.Token{Type: token.SYMBOL, Literal: literal}
} else {
tok = newToken(token.SYMBOL, l.curChar)
}
case 0:
tok.Literal = ""
tok.Type = token.EOF
Expand Down Expand Up @@ -104,7 +131,7 @@ func isDigit(ch byte) bool {

func isSpecialChar(ch byte) bool {
return ch == '*' ||
ch == '/'
ch == '='
}

func isSymbol(ch byte) bool {
Expand All @@ -126,3 +153,10 @@ func (l *Lexer) readNumber() string {
}
return l.input[startPos:l.curPos]
}

func (l *Lexer) peekChar() byte {
if l.nextPos >= len(l.input) {
return 0
}
return l.input[l.nextPos]
}
Loading

0 comments on commit 5956d14

Please sign in to comment.