about summary refs log tree commit diff
path: root/pkg/lang/compiler
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-07-03 18:11:55 +0200
committerMel <einebeere@gmail.com>2022-07-03 18:11:55 +0200
commited9a0c8f0f3c1eed3582d722935cd1df1d055afd (patch)
treede68dd7510f03457693593d7dbd8636804f89ded /pkg/lang/compiler
parent4774f6373d3e41acba54cb4c63ca51f1b3de2ddd (diff)
downloadjinx-ed9a0c8f0f3c1eed3582d722935cd1df1d055afd.tar.zst
jinx-ed9a0c8f0f3c1eed3582d722935cd1df1d055afd.zip
Compile If Stmts
Diffstat (limited to 'pkg/lang/compiler')
-rw-r--r--pkg/lang/compiler/compiler.go214
-rw-r--r--pkg/lang/compiler/compiler_test.go100
2 files changed, 237 insertions, 77 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index 4d7f4a1..f4a900a 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -7,90 +7,148 @@ import (
 )
 
 type Compiler struct {
-	ast     ast.Program
-	builder code.Builder
+	ast ast.Program
 
 	scopes ScopeChain
 }
 
 func New(ast ast.Program) *Compiler {
 	return &Compiler{
-		ast:     ast,
-		builder: code.NewBuilder(),
+		ast: ast,
 
 		scopes: NewScopeChain(),
 	}
 }
 
 func (comp *Compiler) Compile() (code.Code, error) {
+	target := code.NewBuilder()
+
 	for _, stmt := range comp.ast.Stmts {
-		if err := comp.compileStmt(stmt); err != nil {
+		if err := comp.compileStmt(&target, stmt); err != nil {
 			return code.Code{}, err
 		}
 	}
 
-	return comp.builder.Build(), nil
+	target.AppendOp(code.OpHalt)
+
+	return target.Build(), nil
 }
 
-func (comp *Compiler) compileStmt(stmt ast.Stmt) error {
+func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	var err error
 	switch stmt.Kind {
 	case ast.StmtKindEmpty:
 		// Do nothing.
 	case ast.StmtKindVarDecl:
 		decl := stmt.Value.(ast.StmtVarDecl)
-		err = comp.compileVarDeclStmt(decl)
+		err = comp.compileVarDeclStmt(t, decl)
+	case ast.StmtKindIf:
+		ifstmt := stmt.Value.(ast.StmtIf)
+		err = comp.compileIfStmt(t, ifstmt)
 	case ast.StmtKindExpr:
 		expr := stmt.Value.(ast.StmtExpr).Value
-		err = comp.compileExpr(expr)
+		err = comp.compileExpr(t, expr)
 	default:
-		panic("statements other than expressions, variable declarations, var and empties not implemented")
+		panic(fmt.Errorf("statement of kind %v not implemented", stmt.Kind))
 	}
 
 	return err
 }
 
-func (comp *Compiler) compileVarDeclStmt(decl ast.StmtVarDecl) error {
+func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error {
 	if !comp.scopes.Declare(decl.Name.Value) {
 		return fmt.Errorf("variable %s already declared", decl.Name.Value)
 	}
 
-	if err := comp.compileExpr(decl.Value); err != nil {
+	if err := comp.compileExpr(t, decl.Value); err != nil {
 		return err
 	}
 
 	return nil
 }
 
-func (comp *Compiler) compileExpr(expr ast.Expr) error {
+func (comp *Compiler) compileIfStmt(t *code.Builder, ifstmt ast.StmtIf) error {
+	// push_false -> cond // only on ifs and elifs
+	// jf @elif -> condjmp // only on ifs and elifs
+
+	// push_int 1 -> then
+	// jmp @end -> thenjmp // except on last cond
+
+	subUnits := make([]code.Builder, 0, len(ifstmt.Conds))
+
+	totalLength := t.Len()
+	jmpLength := 9 // OP + Uint
+
+	for i, cond := range ifstmt.Conds {
+		thenTarget := code.NewBuilder() // then
+		if err := comp.compileBlockNode(&thenTarget, cond.Then); err != nil {
+			return err
+		}
+
+		totalLength += thenTarget.Len()
+		if i != len(ifstmt.Conds)-1 {
+			totalLength += jmpLength // thenjmp
+		}
+
+		subUnitTarget := code.NewBuilder()
+		if !cond.Cond.IsEmpty() {
+			// cond
+			if err := comp.compileExpr(&subUnitTarget, cond.Cond); err != nil {
+				return err
+			}
+
+			totalLength += subUnitTarget.Len() + jmpLength // condjmp
+
+			subUnitTarget.AppendOp(code.OpJf)
+			subUnitTarget.AppendInt(int64(totalLength))
+		}
+
+		subUnitTarget.AppendRaw(thenTarget.Code())
+
+		subUnits = append(subUnits, subUnitTarget)
+	}
+
+	for i, subUnit := range subUnits {
+		if i != len(ifstmt.Conds)-1 {
+			subUnit.AppendOp(code.OpJmp)
+			subUnit.AppendInt(int64(totalLength))
+		}
+
+		t.AppendRaw(subUnit.Code())
+	}
+
+	return nil
+}
+
+func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 	switch expr.Kind {
 	case ast.ExprKindBinary:
-		return comp.compileBinaryExpr(expr.Value.(ast.ExprBinary))
+		return comp.compileBinaryExpr(t, expr.Value.(ast.ExprBinary))
 	case ast.ExprKindUnary:
-		return comp.compileUnaryExpr(expr.Value.(ast.ExprUnary))
+		return comp.compileUnaryExpr(t, expr.Value.(ast.ExprUnary))
 	case ast.ExprKindCall:
-		return comp.compileCallExpr(expr.Value.(ast.ExprCall))
+		return comp.compileCallExpr(t, expr.Value.(ast.ExprCall))
 	case ast.ExprKindSubscription:
-		return comp.compileSubscriptionExpr(expr.Value.(ast.ExprSubscription))
+		return comp.compileSubscriptionExpr(t, expr.Value.(ast.ExprSubscription))
 
 	case ast.ExprKindGroup:
-		return comp.compileGroupExpr(expr.Value.(ast.ExprGroup))
+		return comp.compileGroupExpr(t, expr.Value.(ast.ExprGroup))
 	case ast.ExprKindFnLit:
 		panic("not implemented")
 	case ast.ExprKindArrayLit:
 		panic("not implemented")
 	case ast.ExprKindIdent:
-		return comp.compileIdentExpr(expr.Value.(ast.ExprIdent))
+		return comp.compileIdentExpr(t, expr.Value.(ast.ExprIdent))
 	case ast.ExprKindIntLit:
-		return comp.compileIntLitExpr(expr.Value.(ast.ExprIntLit))
+		return comp.compileIntLitExpr(t, expr.Value.(ast.ExprIntLit))
 	case ast.ExprKindFloatLit:
-		return comp.compileFloatLitExpr(expr.Value.(ast.ExprFloatLit))
+		return comp.compileFloatLitExpr(t, expr.Value.(ast.ExprFloatLit))
 	case ast.ExprKindStringLit:
-		return comp.compileStringLitExpr(expr.Value.(ast.ExprStringLit))
+		return comp.compileStringLitExpr(t, expr.Value.(ast.ExprStringLit))
 	case ast.ExprKindBoolLit:
-		return comp.compileBoolLitExpr(expr.Value.(ast.ExprBoolLit))
+		return comp.compileBoolLitExpr(t, expr.Value.(ast.ExprBoolLit))
 	case ast.ExprKindNullLit:
-		return comp.compileNullLitExpr(expr.Value.(ast.ExprNullLit))
+		return comp.compileNullLitExpr(t, expr.Value.(ast.ExprNullLit))
 	case ast.ExprKindThis:
 		panic("not implemented")
 	default:
@@ -98,50 +156,50 @@ func (comp *Compiler) compileExpr(expr ast.Expr) error {
 	}
 }
 
-func (comp *Compiler) compileBinaryExpr(expr ast.ExprBinary) error {
+func (comp *Compiler) compileBinaryExpr(t *code.Builder, expr ast.ExprBinary) error {
 	if expr.Op == ast.BinOpAssign {
-		return comp.compileAssignExpr(expr)
+		return comp.compileAssignExpr(t, expr)
 	}
 
-	if err := comp.compileExpr(expr.Left); err != nil {
+	if err := comp.compileExpr(t, expr.Left); err != nil {
 		return err
 	}
 
-	if err := comp.compileExpr(expr.Right); err != nil {
+	if err := comp.compileExpr(t, expr.Right); err != nil {
 		return err
 	}
 
 	switch expr.Op {
 	case ast.BinOpPlus:
-		comp.builder.AppendOp(code.OpAdd)
+		t.AppendOp(code.OpAdd)
 	case ast.BinOpMinus:
-		comp.builder.AppendOp(code.OpSub)
+		t.AppendOp(code.OpSub)
 	case ast.BinOpStar:
-		// comp.builder.AppendOp(code.OpMul)
+		// t.AppendOp(code.OpMul)
 		panic("not implemented")
 	case ast.BinOpSlash:
-		// comp.builder.AppendOp(code.OpDiv)
+		// t.AppendOp(code.OpDiv)
 		panic("not implemented")
 	case ast.BinOpPercent:
-		// comp.builder.AppendOp(code.OpMod)
+		// t.AppendOp(code.OpMod)
 		panic("not implemented")
 
 	case ast.BinOpEq:
-		// comp.builder.AppendOp(code.OpEq)
+		// t.AppendOp(code.OpEq)
 		panic("not implemented")
 	case ast.BinOpNeq:
-		// comp.builder.AppendOp(code.OpNeq)
+		// t.AppendOp(code.OpNeq)
 		panic("not implemented")
 	case ast.BinOpLt:
-		// comp.builder.AppendOp(code.OpLt)
+		// t.AppendOp(code.OpLt)
 		panic("not implemented")
 	case ast.BinOpLte:
-		comp.builder.AppendOp(code.OpLte)
+		t.AppendOp(code.OpLte)
 	case ast.BinOpGt:
-		// comp.builder.AppendOp(code.OpGt)
+		// t.AppendOp(code.OpGt)
 		panic("not implemented")
 	case ast.BinOpGte:
-		// comp.builder.AppendOp(code.OpGte)
+		// t.AppendOp(code.OpGte)
 		panic("not implemented")
 	default:
 		panic("unknown binary operator")
@@ -150,7 +208,7 @@ func (comp *Compiler) compileBinaryExpr(expr ast.ExprBinary) error {
 	return nil
 }
 
-func (comp *Compiler) compileAssignExpr(expr ast.ExprBinary) error {
+func (comp *Compiler) compileAssignExpr(t *code.Builder, expr ast.ExprBinary) error {
 	if expr.Left.Kind != ast.ExprKindIdent {
 		return fmt.Errorf("lvalues other than identifiers not implemented")
 	}
@@ -161,18 +219,18 @@ func (comp *Compiler) compileAssignExpr(expr ast.ExprBinary) error {
 		return fmt.Errorf("variable %s not declared", name)
 	}
 
-	if err := comp.compileExpr(expr.Right); err != nil {
+	if err := comp.compileExpr(t, expr.Right); err != nil {
 		return err
 	}
 
-	comp.builder.AppendOp(code.OpSetLocal)
-	comp.builder.AppendInt(int64(symbol.localIndex))
+	t.AppendOp(code.OpSetLocal)
+	t.AppendInt(int64(symbol.localIndex))
 
 	return nil
 }
 
-func (comp *Compiler) compileUnaryExpr(expr ast.ExprUnary) error {
-	if err := comp.compileExpr(expr.Value); err != nil {
+func (comp *Compiler) compileUnaryExpr(t *code.Builder, expr ast.ExprUnary) error {
+	if err := comp.compileExpr(t, expr.Value); err != nil {
 		return err
 	}
 
@@ -188,81 +246,91 @@ func (comp *Compiler) compileUnaryExpr(expr ast.ExprUnary) error {
 	return nil
 }
 
-func (comp *Compiler) compileCallExpr(expr ast.ExprCall) error {
-	if err := comp.compileExpr(expr.Callee); err != nil {
+func (comp *Compiler) compileCallExpr(t *code.Builder, expr ast.ExprCall) error {
+	if err := comp.compileExpr(t, expr.Callee); err != nil {
 		return err
 	}
 
 	for i := 0; i < len(expr.Args); i++ {
-		if err := comp.compileExpr(expr.Args[i]); err != nil {
+		if err := comp.compileExpr(t, expr.Args[i]); err != nil {
 			return err
 		}
 	}
 
-	comp.builder.AppendOp(code.OpCall)
-	comp.builder.AppendInt(int64(len(expr.Args)))
+	t.AppendOp(code.OpCall)
+	t.AppendInt(int64(len(expr.Args)))
 
 	return nil
 }
 
-func (comp *Compiler) compileSubscriptionExpr(expr ast.ExprSubscription) error {
-	if err := comp.compileExpr(expr.Obj); err != nil {
+func (comp *Compiler) compileSubscriptionExpr(t *code.Builder, expr ast.ExprSubscription) error {
+	if err := comp.compileExpr(t, expr.Obj); err != nil {
 		return err
 	}
 
-	if err := comp.compileExpr(expr.Key); err != nil {
+	if err := comp.compileExpr(t, expr.Key); err != nil {
 		return err
 	}
 
-	comp.builder.AppendOp(code.OpIndex)
+	t.AppendOp(code.OpIndex)
 	return nil
 }
 
-func (comp *Compiler) compileGroupExpr(expr ast.ExprGroup) error {
-	return comp.compileExpr(expr.Value)
+func (comp *Compiler) compileGroupExpr(t *code.Builder, expr ast.ExprGroup) error {
+	return comp.compileExpr(t, expr.Value)
 }
 
-func (comp *Compiler) compileIdentExpr(expr ast.ExprIdent) error {
+func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) error {
 	symbol, ok := comp.scopes.Lookup(expr.Value.Value)
 	if !ok {
 		return fmt.Errorf("undefined symbol %s", expr.Value.Value)
 	}
 
 	// TODO: Add boundries to check how the symbol should be fetched. (local, env, global, etc.)
-	comp.builder.AppendOp(code.OpGetLocal)
-	comp.builder.AppendInt(int64(symbol.localIndex))
+	t.AppendOp(code.OpGetLocal)
+	t.AppendInt(int64(symbol.localIndex))
 
 	return nil
 }
 
-func (comp *Compiler) compileIntLitExpr(expr ast.ExprIntLit) error {
-	comp.builder.AppendOp(code.OpPushInt)
-	comp.builder.AppendInt(int64(expr.Value))
+func (comp *Compiler) compileIntLitExpr(t *code.Builder, expr ast.ExprIntLit) error {
+	t.AppendOp(code.OpPushInt)
+	t.AppendInt(int64(expr.Value))
 	return nil
 }
 
-func (comp *Compiler) compileFloatLitExpr(expr ast.ExprFloatLit) error {
-	comp.builder.AppendOp(code.OpPushFloat)
-	comp.builder.AppendFloat(expr.Value)
+func (comp *Compiler) compileFloatLitExpr(t *code.Builder, expr ast.ExprFloatLit) error {
+	t.AppendOp(code.OpPushFloat)
+	t.AppendFloat(expr.Value)
 	return nil
 }
 
-func (comp *Compiler) compileStringLitExpr(expr ast.ExprStringLit) error {
-	comp.builder.AppendOp(code.OpPushString)
-	comp.builder.AppendString(expr.Value)
+func (comp *Compiler) compileStringLitExpr(t *code.Builder, expr ast.ExprStringLit) error {
+	t.AppendOp(code.OpPushString)
+	t.AppendString(expr.Value)
 	return nil
 }
 
-func (comp *Compiler) compileBoolLitExpr(expr ast.ExprBoolLit) error {
+func (comp *Compiler) compileBoolLitExpr(t *code.Builder, expr ast.ExprBoolLit) error {
 	if expr.Value {
-		comp.builder.AppendOp(code.OpPushTrue)
+		t.AppendOp(code.OpPushTrue)
 	} else {
-		comp.builder.AppendOp(code.OpPushFalse)
+		t.AppendOp(code.OpPushFalse)
 	}
 	return nil
 }
 
-func (comp *Compiler) compileNullLitExpr(expr ast.ExprNullLit) error {
-	comp.builder.AppendOp(code.OpPushNull)
+func (comp *Compiler) compileNullLitExpr(t *code.Builder, expr ast.ExprNullLit) error {
+	t.AppendOp(code.OpPushNull)
+	return nil
+}
+
+func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) error {
+	for _, stmt := range block.Stmts {
+		if err := comp.compileStmt(t, stmt); err != nil {
+			return err
+		}
+	}
+
 	return nil
 }
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index cf0fee2..a6a1ba7 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -1,7 +1,6 @@
 package compiler_test
 
 import (
-	"fmt"
 	"jinx/pkg/lang/compiler"
 	"jinx/pkg/lang/parser"
 	"jinx/pkg/lang/scanner"
@@ -22,6 +21,7 @@ func TestSimpleAddExpr(t *testing.T) {
 	push_int 1
 	push_int 2
 	add
+	halt
 	`
 
 	mustCompileTo(t, src, expected)
@@ -44,6 +44,7 @@ func TestOperationOrder(t *testing.T) {
 	sub
 	push_int 4
 	add
+	halt
 	`
 
 	mustCompileTo(t, grouped, expected)
@@ -76,6 +77,8 @@ func TestNestedExpr(t *testing.T) {
 
 	add
 	sub
+
+	halt
 	`
 
 	mustCompileTo(t, src, expected)
@@ -103,6 +106,98 @@ func TestVars(t *testing.T) {
 	get_local 0
 	get_local 1
 	add
+
+	halt
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
+func TestIf(t *testing.T) {
+	src := `
+	":("
+	if 1 <= 5 {
+		"hello " + "world"
+	}
+	`
+
+	expected := `
+	push_string ":("
+
+	push_int 1
+	push_int 5
+	lte
+
+	jf @end
+
+	push_string "hello "
+	push_string "world"
+	add
+
+	@end:
+	halt
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
+func TestIfElifElse(t *testing.T) {
+	src := `
+		if false {
+			1
+		} elif true {
+			2
+		} else {
+			3
+		}
+	`
+
+	expected := `
+	push_false
+	jf @elif
+
+	push_int 1
+	jmp @end
+
+	@elif:
+	push_true
+	jf @else
+	push_int 2
+	jmp @end
+
+	@else:
+	push_int 3
+
+	@end:
+	halt
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
+func TestIfNoElse(t *testing.T) {
+	src := `
+		if false {
+			1
+		} elif true {
+			2
+		}
+	`
+
+	expected := `
+	push_false
+	jf @elif
+
+	push_int 1
+	jmp @end
+
+	@elif:
+	push_true
+	jf @end
+	push_int 2
+
+	@end:
+	halt
 	`
 
 	mustCompileTo(t, src, expected)
@@ -117,9 +212,6 @@ func mustCompileTo(t *testing.T, src, expected string) {
 	program, err := parser.Parse()
 	require.NoError(t, err)
 
-	// spew.Dump(program)
-	fmt.Printf("%#v\n", program)
-
 	langCompiler := compiler.New(program)
 	testResult, err := langCompiler.Compile()
 	require.NoError(t, err)