diff options
| -rw-r--r-- | pkg/lang/ast/expr.go | 8 | ||||
| -rw-r--r-- | pkg/lang/ast/stmt.go | 5 | ||||
| -rw-r--r-- | pkg/lang/compiler/compiler.go | 214 | ||||
| -rw-r--r-- | pkg/lang/compiler/compiler_test.go | 100 | ||||
| -rw-r--r-- | pkg/lang/parser/parser_test.go | 48 | ||||
| -rw-r--r-- | pkg/lang/parser/stmts.go | 28 |
6 files changed, 289 insertions, 114 deletions
diff --git a/pkg/lang/ast/expr.go b/pkg/lang/ast/expr.go index f98f9b3..b0ed599 100644 --- a/pkg/lang/ast/expr.go +++ b/pkg/lang/ast/expr.go @@ -1,6 +1,8 @@ package ast -import "jinx/pkg/libs/source" +import ( + "jinx/pkg/libs/source" +) type ExprKind int @@ -24,6 +26,10 @@ const ( type Expr ExprT[any] +func (e Expr) IsEmpty() bool { + return e == Expr{} +} + type ExprT[T any] struct { At source.Loc Kind ExprKind diff --git a/pkg/lang/ast/stmt.go b/pkg/lang/ast/stmt.go index e25dee9..6395f57 100644 --- a/pkg/lang/ast/stmt.go +++ b/pkg/lang/ast/stmt.go @@ -47,10 +47,7 @@ type StmtVarDecl struct { } type StmtIf struct { - Cond Expr - Then BlockNode - Elifs []CondNode - Else BlockNode + Conds []CondNode } type StmtTry struct { diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go index 4d7f4a1..f4a900a 100644 --- a/pkg/lang/compiler/compiler.go +++ b/pkg/lang/compiler/compiler.go @@ -7,90 +7,148 @@ import ( ) type Compiler struct { - ast ast.Program - builder code.Builder + ast ast.Program scopes ScopeChain } func New(ast ast.Program) *Compiler { return &Compiler{ - ast: ast, - builder: code.NewBuilder(), + ast: ast, scopes: NewScopeChain(), } } func (comp *Compiler) Compile() (code.Code, error) { + target := code.NewBuilder() + for _, stmt := range comp.ast.Stmts { - if err := comp.compileStmt(stmt); err != nil { + if err := comp.compileStmt(&target, stmt); err != nil { return code.Code{}, err } } - return comp.builder.Build(), nil + target.AppendOp(code.OpHalt) + + return target.Build(), nil } -func (comp *Compiler) compileStmt(stmt ast.Stmt) error { +func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error { var err error switch stmt.Kind { case ast.StmtKindEmpty: // Do nothing. case ast.StmtKindVarDecl: decl := stmt.Value.(ast.StmtVarDecl) - err = comp.compileVarDeclStmt(decl) + err = comp.compileVarDeclStmt(t, decl) + case ast.StmtKindIf: + ifstmt := stmt.Value.(ast.StmtIf) + err = comp.compileIfStmt(t, ifstmt) case ast.StmtKindExpr: expr := stmt.Value.(ast.StmtExpr).Value - err = comp.compileExpr(expr) + err = comp.compileExpr(t, expr) default: - panic("statements other than expressions, variable declarations, var and empties not implemented") + panic(fmt.Errorf("statement of kind %v not implemented", stmt.Kind)) } return err } -func (comp *Compiler) compileVarDeclStmt(decl ast.StmtVarDecl) error { +func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error { if !comp.scopes.Declare(decl.Name.Value) { return fmt.Errorf("variable %s already declared", decl.Name.Value) } - if err := comp.compileExpr(decl.Value); err != nil { + if err := comp.compileExpr(t, decl.Value); err != nil { return err } return nil } -func (comp *Compiler) compileExpr(expr ast.Expr) error { +func (comp *Compiler) compileIfStmt(t *code.Builder, ifstmt ast.StmtIf) error { + // push_false -> cond // only on ifs and elifs + // jf @elif -> condjmp // only on ifs and elifs + + // push_int 1 -> then + // jmp @end -> thenjmp // except on last cond + + subUnits := make([]code.Builder, 0, len(ifstmt.Conds)) + + totalLength := t.Len() + jmpLength := 9 // OP + Uint + + for i, cond := range ifstmt.Conds { + thenTarget := code.NewBuilder() // then + if err := comp.compileBlockNode(&thenTarget, cond.Then); err != nil { + return err + } + + totalLength += thenTarget.Len() + if i != len(ifstmt.Conds)-1 { + totalLength += jmpLength // thenjmp + } + + subUnitTarget := code.NewBuilder() + if !cond.Cond.IsEmpty() { + // cond + if err := comp.compileExpr(&subUnitTarget, cond.Cond); err != nil { + return err + } + + totalLength += subUnitTarget.Len() + jmpLength // condjmp + + subUnitTarget.AppendOp(code.OpJf) + subUnitTarget.AppendInt(int64(totalLength)) + } + + subUnitTarget.AppendRaw(thenTarget.Code()) + + subUnits = append(subUnits, subUnitTarget) + } + + for i, subUnit := range subUnits { + if i != len(ifstmt.Conds)-1 { + subUnit.AppendOp(code.OpJmp) + subUnit.AppendInt(int64(totalLength)) + } + + t.AppendRaw(subUnit.Code()) + } + + return nil +} + +func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error { switch expr.Kind { case ast.ExprKindBinary: - return comp.compileBinaryExpr(expr.Value.(ast.ExprBinary)) + return comp.compileBinaryExpr(t, expr.Value.(ast.ExprBinary)) case ast.ExprKindUnary: - return comp.compileUnaryExpr(expr.Value.(ast.ExprUnary)) + return comp.compileUnaryExpr(t, expr.Value.(ast.ExprUnary)) case ast.ExprKindCall: - return comp.compileCallExpr(expr.Value.(ast.ExprCall)) + return comp.compileCallExpr(t, expr.Value.(ast.ExprCall)) case ast.ExprKindSubscription: - return comp.compileSubscriptionExpr(expr.Value.(ast.ExprSubscription)) + return comp.compileSubscriptionExpr(t, expr.Value.(ast.ExprSubscription)) case ast.ExprKindGroup: - return comp.compileGroupExpr(expr.Value.(ast.ExprGroup)) + return comp.compileGroupExpr(t, expr.Value.(ast.ExprGroup)) case ast.ExprKindFnLit: panic("not implemented") case ast.ExprKindArrayLit: panic("not implemented") case ast.ExprKindIdent: - return comp.compileIdentExpr(expr.Value.(ast.ExprIdent)) + return comp.compileIdentExpr(t, expr.Value.(ast.ExprIdent)) case ast.ExprKindIntLit: - return comp.compileIntLitExpr(expr.Value.(ast.ExprIntLit)) + return comp.compileIntLitExpr(t, expr.Value.(ast.ExprIntLit)) case ast.ExprKindFloatLit: - return comp.compileFloatLitExpr(expr.Value.(ast.ExprFloatLit)) + return comp.compileFloatLitExpr(t, expr.Value.(ast.ExprFloatLit)) case ast.ExprKindStringLit: - return comp.compileStringLitExpr(expr.Value.(ast.ExprStringLit)) + return comp.compileStringLitExpr(t, expr.Value.(ast.ExprStringLit)) case ast.ExprKindBoolLit: - return comp.compileBoolLitExpr(expr.Value.(ast.ExprBoolLit)) + return comp.compileBoolLitExpr(t, expr.Value.(ast.ExprBoolLit)) case ast.ExprKindNullLit: - return comp.compileNullLitExpr(expr.Value.(ast.ExprNullLit)) + return comp.compileNullLitExpr(t, expr.Value.(ast.ExprNullLit)) case ast.ExprKindThis: panic("not implemented") default: @@ -98,50 +156,50 @@ func (comp *Compiler) compileExpr(expr ast.Expr) error { } } -func (comp *Compiler) compileBinaryExpr(expr ast.ExprBinary) error { +func (comp *Compiler) compileBinaryExpr(t *code.Builder, expr ast.ExprBinary) error { if expr.Op == ast.BinOpAssign { - return comp.compileAssignExpr(expr) + return comp.compileAssignExpr(t, expr) } - if err := comp.compileExpr(expr.Left); err != nil { + if err := comp.compileExpr(t, expr.Left); err != nil { return err } - if err := comp.compileExpr(expr.Right); err != nil { + if err := comp.compileExpr(t, expr.Right); err != nil { return err } switch expr.Op { case ast.BinOpPlus: - comp.builder.AppendOp(code.OpAdd) + t.AppendOp(code.OpAdd) case ast.BinOpMinus: - comp.builder.AppendOp(code.OpSub) + t.AppendOp(code.OpSub) case ast.BinOpStar: - // comp.builder.AppendOp(code.OpMul) + // t.AppendOp(code.OpMul) panic("not implemented") case ast.BinOpSlash: - // comp.builder.AppendOp(code.OpDiv) + // t.AppendOp(code.OpDiv) panic("not implemented") case ast.BinOpPercent: - // comp.builder.AppendOp(code.OpMod) + // t.AppendOp(code.OpMod) panic("not implemented") case ast.BinOpEq: - // comp.builder.AppendOp(code.OpEq) + // t.AppendOp(code.OpEq) panic("not implemented") case ast.BinOpNeq: - // comp.builder.AppendOp(code.OpNeq) + // t.AppendOp(code.OpNeq) panic("not implemented") case ast.BinOpLt: - // comp.builder.AppendOp(code.OpLt) + // t.AppendOp(code.OpLt) panic("not implemented") case ast.BinOpLte: - comp.builder.AppendOp(code.OpLte) + t.AppendOp(code.OpLte) case ast.BinOpGt: - // comp.builder.AppendOp(code.OpGt) + // t.AppendOp(code.OpGt) panic("not implemented") case ast.BinOpGte: - // comp.builder.AppendOp(code.OpGte) + // t.AppendOp(code.OpGte) panic("not implemented") default: panic("unknown binary operator") @@ -150,7 +208,7 @@ func (comp *Compiler) compileBinaryExpr(expr ast.ExprBinary) error { return nil } -func (comp *Compiler) compileAssignExpr(expr ast.ExprBinary) error { +func (comp *Compiler) compileAssignExpr(t *code.Builder, expr ast.ExprBinary) error { if expr.Left.Kind != ast.ExprKindIdent { return fmt.Errorf("lvalues other than identifiers not implemented") } @@ -161,18 +219,18 @@ func (comp *Compiler) compileAssignExpr(expr ast.ExprBinary) error { return fmt.Errorf("variable %s not declared", name) } - if err := comp.compileExpr(expr.Right); err != nil { + if err := comp.compileExpr(t, expr.Right); err != nil { return err } - comp.builder.AppendOp(code.OpSetLocal) - comp.builder.AppendInt(int64(symbol.localIndex)) + t.AppendOp(code.OpSetLocal) + t.AppendInt(int64(symbol.localIndex)) return nil } -func (comp *Compiler) compileUnaryExpr(expr ast.ExprUnary) error { - if err := comp.compileExpr(expr.Value); err != nil { +func (comp *Compiler) compileUnaryExpr(t *code.Builder, expr ast.ExprUnary) error { + if err := comp.compileExpr(t, expr.Value); err != nil { return err } @@ -188,81 +246,91 @@ func (comp *Compiler) compileUnaryExpr(expr ast.ExprUnary) error { return nil } -func (comp *Compiler) compileCallExpr(expr ast.ExprCall) error { - if err := comp.compileExpr(expr.Callee); err != nil { +func (comp *Compiler) compileCallExpr(t *code.Builder, expr ast.ExprCall) error { + if err := comp.compileExpr(t, expr.Callee); err != nil { return err } for i := 0; i < len(expr.Args); i++ { - if err := comp.compileExpr(expr.Args[i]); err != nil { + if err := comp.compileExpr(t, expr.Args[i]); err != nil { return err } } - comp.builder.AppendOp(code.OpCall) - comp.builder.AppendInt(int64(len(expr.Args))) + t.AppendOp(code.OpCall) + t.AppendInt(int64(len(expr.Args))) return nil } -func (comp *Compiler) compileSubscriptionExpr(expr ast.ExprSubscription) error { - if err := comp.compileExpr(expr.Obj); err != nil { +func (comp *Compiler) compileSubscriptionExpr(t *code.Builder, expr ast.ExprSubscription) error { + if err := comp.compileExpr(t, expr.Obj); err != nil { return err } - if err := comp.compileExpr(expr.Key); err != nil { + if err := comp.compileExpr(t, expr.Key); err != nil { return err } - comp.builder.AppendOp(code.OpIndex) + t.AppendOp(code.OpIndex) return nil } -func (comp *Compiler) compileGroupExpr(expr ast.ExprGroup) error { - return comp.compileExpr(expr.Value) +func (comp *Compiler) compileGroupExpr(t *code.Builder, expr ast.ExprGroup) error { + return comp.compileExpr(t, expr.Value) } -func (comp *Compiler) compileIdentExpr(expr ast.ExprIdent) error { +func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) error { symbol, ok := comp.scopes.Lookup(expr.Value.Value) if !ok { return fmt.Errorf("undefined symbol %s", expr.Value.Value) } // TODO: Add boundries to check how the symbol should be fetched. (local, env, global, etc.) - comp.builder.AppendOp(code.OpGetLocal) - comp.builder.AppendInt(int64(symbol.localIndex)) + t.AppendOp(code.OpGetLocal) + t.AppendInt(int64(symbol.localIndex)) return nil } -func (comp *Compiler) compileIntLitExpr(expr ast.ExprIntLit) error { - comp.builder.AppendOp(code.OpPushInt) - comp.builder.AppendInt(int64(expr.Value)) +func (comp *Compiler) compileIntLitExpr(t *code.Builder, expr ast.ExprIntLit) error { + t.AppendOp(code.OpPushInt) + t.AppendInt(int64(expr.Value)) return nil } -func (comp *Compiler) compileFloatLitExpr(expr ast.ExprFloatLit) error { - comp.builder.AppendOp(code.OpPushFloat) - comp.builder.AppendFloat(expr.Value) +func (comp *Compiler) compileFloatLitExpr(t *code.Builder, expr ast.ExprFloatLit) error { + t.AppendOp(code.OpPushFloat) + t.AppendFloat(expr.Value) return nil } -func (comp *Compiler) compileStringLitExpr(expr ast.ExprStringLit) error { - comp.builder.AppendOp(code.OpPushString) - comp.builder.AppendString(expr.Value) +func (comp *Compiler) compileStringLitExpr(t *code.Builder, expr ast.ExprStringLit) error { + t.AppendOp(code.OpPushString) + t.AppendString(expr.Value) return nil } -func (comp *Compiler) compileBoolLitExpr(expr ast.ExprBoolLit) error { +func (comp *Compiler) compileBoolLitExpr(t *code.Builder, expr ast.ExprBoolLit) error { if expr.Value { - comp.builder.AppendOp(code.OpPushTrue) + t.AppendOp(code.OpPushTrue) } else { - comp.builder.AppendOp(code.OpPushFalse) + t.AppendOp(code.OpPushFalse) } return nil } -func (comp *Compiler) compileNullLitExpr(expr ast.ExprNullLit) error { - comp.builder.AppendOp(code.OpPushNull) +func (comp *Compiler) compileNullLitExpr(t *code.Builder, expr ast.ExprNullLit) error { + t.AppendOp(code.OpPushNull) + return nil +} + +func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) error { + for _, stmt := range block.Stmts { + if err := comp.compileStmt(t, stmt); err != nil { + return err + } + } + return nil } diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go index cf0fee2..a6a1ba7 100644 --- a/pkg/lang/compiler/compiler_test.go +++ b/pkg/lang/compiler/compiler_test.go @@ -1,7 +1,6 @@ package compiler_test import ( - "fmt" "jinx/pkg/lang/compiler" "jinx/pkg/lang/parser" "jinx/pkg/lang/scanner" @@ -22,6 +21,7 @@ func TestSimpleAddExpr(t *testing.T) { push_int 1 push_int 2 add + halt ` mustCompileTo(t, src, expected) @@ -44,6 +44,7 @@ func TestOperationOrder(t *testing.T) { sub push_int 4 add + halt ` mustCompileTo(t, grouped, expected) @@ -76,6 +77,8 @@ func TestNestedExpr(t *testing.T) { add sub + + halt ` mustCompileTo(t, src, expected) @@ -103,6 +106,98 @@ func TestVars(t *testing.T) { get_local 0 get_local 1 add + + halt + ` + + mustCompileTo(t, src, expected) +} + +func TestIf(t *testing.T) { + src := ` + ":(" + if 1 <= 5 { + "hello " + "world" + } + ` + + expected := ` + push_string ":(" + + push_int 1 + push_int 5 + lte + + jf @end + + push_string "hello " + push_string "world" + add + + @end: + halt + ` + + mustCompileTo(t, src, expected) +} + +func TestIfElifElse(t *testing.T) { + src := ` + if false { + 1 + } elif true { + 2 + } else { + 3 + } + ` + + expected := ` + push_false + jf @elif + + push_int 1 + jmp @end + + @elif: + push_true + jf @else + push_int 2 + jmp @end + + @else: + push_int 3 + + @end: + halt + ` + + mustCompileTo(t, src, expected) +} + +func TestIfNoElse(t *testing.T) { + src := ` + if false { + 1 + } elif true { + 2 + } + ` + + expected := ` + push_false + jf @elif + + push_int 1 + jmp @end + + @elif: + push_true + jf @end + push_int 2 + + @end: + halt ` mustCompileTo(t, src, expected) @@ -117,9 +212,6 @@ func mustCompileTo(t *testing.T, src, expected string) { program, err := parser.Parse() require.NoError(t, err) - // spew.Dump(program) - fmt.Printf("%#v\n", program) - langCompiler := compiler.New(program) testResult, err := langCompiler.Compile() require.NoError(t, err) diff --git a/pkg/lang/parser/parser_test.go b/pkg/lang/parser/parser_test.go index f48b6a5..b756b1b 100644 --- a/pkg/lang/parser/parser_test.go +++ b/pkg/lang/parser/parser_test.go @@ -384,34 +384,38 @@ func TestIfStmt(t *testing.T) { require.Equal(t, ast.Stmt{ Kind: ast.StmtKindIf, Value: ast.StmtIf{ - Cond: ast.Expr{ - At: source.NewLoc(0, 3), - Kind: ast.ExprKindBoolLit, - Value: ast.ExprBoolLit{Value: false}, - }, - Then: ast.BlockNode{ - At: source.NewLoc(0, 9), - Stmts: []ast.Stmt{ - { - At: source.NewLoc(0, 10), - Kind: ast.StmtKindEmpty, - Value: ast.StmtEmpty{}, + Conds: []ast.CondNode{ + { + At: source.NewLoc(0, 0), + Cond: ast.Expr{ + At: source.NewLoc(0, 3), + Kind: ast.ExprKindBoolLit, + Value: ast.ExprBoolLit{Value: false}, }, - { - At: source.NewLoc(1, 1), - Kind: ast.StmtKindVarDecl, - Value: ast.StmtVarDecl{ - Name: ast.IdentNode{At: source.NewLoc(1, 5), Value: "x"}, - Value: ast.Expr{ - At: source.NewLoc(1, 9), - Kind: ast.ExprKindIntLit, - Value: ast.ExprIntLit{Value: 2}, + Then: ast.BlockNode{ + At: source.NewLoc(0, 9), + Stmts: []ast.Stmt{ + { + At: source.NewLoc(0, 10), + Kind: ast.StmtKindEmpty, + Value: ast.StmtEmpty{}, + }, + { + At: source.NewLoc(1, 1), + Kind: ast.StmtKindVarDecl, + Value: ast.StmtVarDecl{ + Name: ast.IdentNode{At: source.NewLoc(1, 5), Value: "x"}, + Value: ast.Expr{ + At: source.NewLoc(1, 9), + Kind: ast.ExprKindIntLit, + Value: ast.ExprIntLit{Value: 2}, + }, + }, }, }, }, }, }, - Elifs: []ast.CondNode{}, }, }, program.Stmts[0]) } diff --git a/pkg/lang/parser/stmts.go b/pkg/lang/parser/stmts.go index 52e590c..4b1077e 100644 --- a/pkg/lang/parser/stmts.go +++ b/pkg/lang/parser/stmts.go @@ -151,7 +151,9 @@ func (p *Parser) parseIfStmt() (ast.Stmt, error) { return ast.Stmt{}, err } - cond, err := p.parseExpr() + conds := []ast.CondNode{} + + ifCond, err := p.parseExpr() if err != nil { return ast.Stmt{}, err } @@ -161,7 +163,11 @@ func (p *Parser) parseIfStmt() (ast.Stmt, error) { return ast.Stmt{}, err } - elifs := []ast.CondNode{} + conds = append(conds, ast.CondNode{ + At: ifTok.At, + Cond: ifCond, + Then: then, + }) for p.peek().Kind == token.KwElif { elifTok, err := p.expect(token.KwElif) @@ -179,35 +185,37 @@ func (p *Parser) parseIfStmt() (ast.Stmt, error) { return ast.Stmt{}, err } - elifs = append(elifs, ast.CondNode{ + conds = append(conds, ast.CondNode{ At: elifTok.At, Cond: elifCond, Then: elifThen, }) } - elseThen := ast.BlockNode{} if p.peek().Kind == token.KwElse { - _, err := p.expect(token.KwElse) + elseTok, err := p.expect(token.KwElse) if err != nil { return ast.Stmt{}, err } - elseThen, err = p.parseBlock() + elseThen, err := p.parseBlock() if err != nil { return ast.Stmt{}, err } + + conds = append(conds, ast.CondNode{ + At: elseTok.At, + Cond: ast.Expr{}, + Then: elseThen, + }) } return ast.Stmt{ At: ifTok.At, Kind: ast.StmtKindIf, Value: ast.StmtIf{ - Cond: cond, - Then: then, - Elifs: elifs, - Else: elseThen, + Conds: conds, }, }, nil } |
