about summary refs log tree commit diff
path: root/pkg/lang/compiler/compiler.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/lang/compiler/compiler.go')
-rw-r--r--pkg/lang/compiler/compiler.go214
1 files changed, 141 insertions, 73 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index 4d7f4a1..f4a900a 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -7,90 +7,148 @@ import (
 )
 
 type Compiler struct {
-	ast     ast.Program
-	builder code.Builder
+	ast ast.Program
 
 	scopes ScopeChain
 }
 
 func New(ast ast.Program) *Compiler {
 	return &Compiler{
-		ast:     ast,
-		builder: code.NewBuilder(),
+		ast: ast,
 
 		scopes: NewScopeChain(),
 	}
 }
 
 func (comp *Compiler) Compile() (code.Code, error) {
+	target := code.NewBuilder()
+
 	for _, stmt := range comp.ast.Stmts {
-		if err := comp.compileStmt(stmt); err != nil {
+		if err := comp.compileStmt(&target, stmt); err != nil {
 			return code.Code{}, err
 		}
 	}
 
-	return comp.builder.Build(), nil
+	target.AppendOp(code.OpHalt)
+
+	return target.Build(), nil
 }
 
-func (comp *Compiler) compileStmt(stmt ast.Stmt) error {
+func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	var err error
 	switch stmt.Kind {
 	case ast.StmtKindEmpty:
 		// Do nothing.
 	case ast.StmtKindVarDecl:
 		decl := stmt.Value.(ast.StmtVarDecl)
-		err = comp.compileVarDeclStmt(decl)
+		err = comp.compileVarDeclStmt(t, decl)
+	case ast.StmtKindIf:
+		ifstmt := stmt.Value.(ast.StmtIf)
+		err = comp.compileIfStmt(t, ifstmt)
 	case ast.StmtKindExpr:
 		expr := stmt.Value.(ast.StmtExpr).Value
-		err = comp.compileExpr(expr)
+		err = comp.compileExpr(t, expr)
 	default:
-		panic("statements other than expressions, variable declarations, var and empties not implemented")
+		panic(fmt.Errorf("statement of kind %v not implemented", stmt.Kind))
 	}
 
 	return err
 }
 
-func (comp *Compiler) compileVarDeclStmt(decl ast.StmtVarDecl) error {
+func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error {
 	if !comp.scopes.Declare(decl.Name.Value) {
 		return fmt.Errorf("variable %s already declared", decl.Name.Value)
 	}
 
-	if err := comp.compileExpr(decl.Value); err != nil {
+	if err := comp.compileExpr(t, decl.Value); err != nil {
 		return err
 	}
 
 	return nil
 }
 
-func (comp *Compiler) compileExpr(expr ast.Expr) error {
+func (comp *Compiler) compileIfStmt(t *code.Builder, ifstmt ast.StmtIf) error {
+	// push_false -> cond // only on ifs and elifs
+	// jf @elif -> condjmp // only on ifs and elifs
+
+	// push_int 1 -> then
+	// jmp @end -> thenjmp // except on last cond
+
+	subUnits := make([]code.Builder, 0, len(ifstmt.Conds))
+
+	totalLength := t.Len()
+	jmpLength := 9 // OP + Uint
+
+	for i, cond := range ifstmt.Conds {
+		thenTarget := code.NewBuilder() // then
+		if err := comp.compileBlockNode(&thenTarget, cond.Then); err != nil {
+			return err
+		}
+
+		totalLength += thenTarget.Len()
+		if i != len(ifstmt.Conds)-1 {
+			totalLength += jmpLength // thenjmp
+		}
+
+		subUnitTarget := code.NewBuilder()
+		if !cond.Cond.IsEmpty() {
+			// cond
+			if err := comp.compileExpr(&subUnitTarget, cond.Cond); err != nil {
+				return err
+			}
+
+			totalLength += subUnitTarget.Len() + jmpLength // condjmp
+
+			subUnitTarget.AppendOp(code.OpJf)
+			subUnitTarget.AppendInt(int64(totalLength))
+		}
+
+		subUnitTarget.AppendRaw(thenTarget.Code())
+
+		subUnits = append(subUnits, subUnitTarget)
+	}
+
+	for i, subUnit := range subUnits {
+		if i != len(ifstmt.Conds)-1 {
+			subUnit.AppendOp(code.OpJmp)
+			subUnit.AppendInt(int64(totalLength))
+		}
+
+		t.AppendRaw(subUnit.Code())
+	}
+
+	return nil
+}
+
+func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 	switch expr.Kind {
 	case ast.ExprKindBinary:
-		return comp.compileBinaryExpr(expr.Value.(ast.ExprBinary))
+		return comp.compileBinaryExpr(t, expr.Value.(ast.ExprBinary))
 	case ast.ExprKindUnary:
-		return comp.compileUnaryExpr(expr.Value.(ast.ExprUnary))
+		return comp.compileUnaryExpr(t, expr.Value.(ast.ExprUnary))
 	case ast.ExprKindCall:
-		return comp.compileCallExpr(expr.Value.(ast.ExprCall))
+		return comp.compileCallExpr(t, expr.Value.(ast.ExprCall))
 	case ast.ExprKindSubscription:
-		return comp.compileSubscriptionExpr(expr.Value.(ast.ExprSubscription))
+		return comp.compileSubscriptionExpr(t, expr.Value.(ast.ExprSubscription))
 
 	case ast.ExprKindGroup:
-		return comp.compileGroupExpr(expr.Value.(ast.ExprGroup))
+		return comp.compileGroupExpr(t, expr.Value.(ast.ExprGroup))
 	case ast.ExprKindFnLit:
 		panic("not implemented")
 	case ast.ExprKindArrayLit:
 		panic("not implemented")
 	case ast.ExprKindIdent:
-		return comp.compileIdentExpr(expr.Value.(ast.ExprIdent))
+		return comp.compileIdentExpr(t, expr.Value.(ast.ExprIdent))
 	case ast.ExprKindIntLit:
-		return comp.compileIntLitExpr(expr.Value.(ast.ExprIntLit))
+		return comp.compileIntLitExpr(t, expr.Value.(ast.ExprIntLit))
 	case ast.ExprKindFloatLit:
-		return comp.compileFloatLitExpr(expr.Value.(ast.ExprFloatLit))
+		return comp.compileFloatLitExpr(t, expr.Value.(ast.ExprFloatLit))
 	case ast.ExprKindStringLit:
-		return comp.compileStringLitExpr(expr.Value.(ast.ExprStringLit))
+		return comp.compileStringLitExpr(t, expr.Value.(ast.ExprStringLit))
 	case ast.ExprKindBoolLit:
-		return comp.compileBoolLitExpr(expr.Value.(ast.ExprBoolLit))
+		return comp.compileBoolLitExpr(t, expr.Value.(ast.ExprBoolLit))
 	case ast.ExprKindNullLit:
-		return comp.compileNullLitExpr(expr.Value.(ast.ExprNullLit))
+		return comp.compileNullLitExpr(t, expr.Value.(ast.ExprNullLit))
 	case ast.ExprKindThis:
 		panic("not implemented")
 	default:
@@ -98,50 +156,50 @@ func (comp *Compiler) compileExpr(expr ast.Expr) error {
 	}
 }
 
-func (comp *Compiler) compileBinaryExpr(expr ast.ExprBinary) error {
+func (comp *Compiler) compileBinaryExpr(t *code.Builder, expr ast.ExprBinary) error {
 	if expr.Op == ast.BinOpAssign {
-		return comp.compileAssignExpr(expr)
+		return comp.compileAssignExpr(t, expr)
 	}
 
-	if err := comp.compileExpr(expr.Left); err != nil {
+	if err := comp.compileExpr(t, expr.Left); err != nil {
 		return err
 	}
 
-	if err := comp.compileExpr(expr.Right); err != nil {
+	if err := comp.compileExpr(t, expr.Right); err != nil {
 		return err
 	}
 
 	switch expr.Op {
 	case ast.BinOpPlus:
-		comp.builder.AppendOp(code.OpAdd)
+		t.AppendOp(code.OpAdd)
 	case ast.BinOpMinus:
-		comp.builder.AppendOp(code.OpSub)
+		t.AppendOp(code.OpSub)
 	case ast.BinOpStar:
-		// comp.builder.AppendOp(code.OpMul)
+		// t.AppendOp(code.OpMul)
 		panic("not implemented")
 	case ast.BinOpSlash:
-		// comp.builder.AppendOp(code.OpDiv)
+		// t.AppendOp(code.OpDiv)
 		panic("not implemented")
 	case ast.BinOpPercent:
-		// comp.builder.AppendOp(code.OpMod)
+		// t.AppendOp(code.OpMod)
 		panic("not implemented")
 
 	case ast.BinOpEq:
-		// comp.builder.AppendOp(code.OpEq)
+		// t.AppendOp(code.OpEq)
 		panic("not implemented")
 	case ast.BinOpNeq:
-		// comp.builder.AppendOp(code.OpNeq)
+		// t.AppendOp(code.OpNeq)
 		panic("not implemented")
 	case ast.BinOpLt:
-		// comp.builder.AppendOp(code.OpLt)
+		// t.AppendOp(code.OpLt)
 		panic("not implemented")
 	case ast.BinOpLte:
-		comp.builder.AppendOp(code.OpLte)
+		t.AppendOp(code.OpLte)
 	case ast.BinOpGt:
-		// comp.builder.AppendOp(code.OpGt)
+		// t.AppendOp(code.OpGt)
 		panic("not implemented")
 	case ast.BinOpGte:
-		// comp.builder.AppendOp(code.OpGte)
+		// t.AppendOp(code.OpGte)
 		panic("not implemented")
 	default:
 		panic("unknown binary operator")
@@ -150,7 +208,7 @@ func (comp *Compiler) compileBinaryExpr(expr ast.ExprBinary) error {
 	return nil
 }
 
-func (comp *Compiler) compileAssignExpr(expr ast.ExprBinary) error {
+func (comp *Compiler) compileAssignExpr(t *code.Builder, expr ast.ExprBinary) error {
 	if expr.Left.Kind != ast.ExprKindIdent {
 		return fmt.Errorf("lvalues other than identifiers not implemented")
 	}
@@ -161,18 +219,18 @@ func (comp *Compiler) compileAssignExpr(expr ast.ExprBinary) error {
 		return fmt.Errorf("variable %s not declared", name)
 	}
 
-	if err := comp.compileExpr(expr.Right); err != nil {
+	if err := comp.compileExpr(t, expr.Right); err != nil {
 		return err
 	}
 
-	comp.builder.AppendOp(code.OpSetLocal)
-	comp.builder.AppendInt(int64(symbol.localIndex))
+	t.AppendOp(code.OpSetLocal)
+	t.AppendInt(int64(symbol.localIndex))
 
 	return nil
 }
 
-func (comp *Compiler) compileUnaryExpr(expr ast.ExprUnary) error {
-	if err := comp.compileExpr(expr.Value); err != nil {
+func (comp *Compiler) compileUnaryExpr(t *code.Builder, expr ast.ExprUnary) error {
+	if err := comp.compileExpr(t, expr.Value); err != nil {
 		return err
 	}
 
@@ -188,81 +246,91 @@ func (comp *Compiler) compileUnaryExpr(expr ast.ExprUnary) error {
 	return nil
 }
 
-func (comp *Compiler) compileCallExpr(expr ast.ExprCall) error {
-	if err := comp.compileExpr(expr.Callee); err != nil {
+func (comp *Compiler) compileCallExpr(t *code.Builder, expr ast.ExprCall) error {
+	if err := comp.compileExpr(t, expr.Callee); err != nil {
 		return err
 	}
 
 	for i := 0; i < len(expr.Args); i++ {
-		if err := comp.compileExpr(expr.Args[i]); err != nil {
+		if err := comp.compileExpr(t, expr.Args[i]); err != nil {
 			return err
 		}
 	}
 
-	comp.builder.AppendOp(code.OpCall)
-	comp.builder.AppendInt(int64(len(expr.Args)))
+	t.AppendOp(code.OpCall)
+	t.AppendInt(int64(len(expr.Args)))
 
 	return nil
 }
 
-func (comp *Compiler) compileSubscriptionExpr(expr ast.ExprSubscription) error {
-	if err := comp.compileExpr(expr.Obj); err != nil {
+func (comp *Compiler) compileSubscriptionExpr(t *code.Builder, expr ast.ExprSubscription) error {
+	if err := comp.compileExpr(t, expr.Obj); err != nil {
 		return err
 	}
 
-	if err := comp.compileExpr(expr.Key); err != nil {
+	if err := comp.compileExpr(t, expr.Key); err != nil {
 		return err
 	}
 
-	comp.builder.AppendOp(code.OpIndex)
+	t.AppendOp(code.OpIndex)
 	return nil
 }
 
-func (comp *Compiler) compileGroupExpr(expr ast.ExprGroup) error {
-	return comp.compileExpr(expr.Value)
+func (comp *Compiler) compileGroupExpr(t *code.Builder, expr ast.ExprGroup) error {
+	return comp.compileExpr(t, expr.Value)
 }
 
-func (comp *Compiler) compileIdentExpr(expr ast.ExprIdent) error {
+func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) error {
 	symbol, ok := comp.scopes.Lookup(expr.Value.Value)
 	if !ok {
 		return fmt.Errorf("undefined symbol %s", expr.Value.Value)
 	}
 
 	// TODO: Add boundries to check how the symbol should be fetched. (local, env, global, etc.)
-	comp.builder.AppendOp(code.OpGetLocal)
-	comp.builder.AppendInt(int64(symbol.localIndex))
+	t.AppendOp(code.OpGetLocal)
+	t.AppendInt(int64(symbol.localIndex))
 
 	return nil
 }
 
-func (comp *Compiler) compileIntLitExpr(expr ast.ExprIntLit) error {
-	comp.builder.AppendOp(code.OpPushInt)
-	comp.builder.AppendInt(int64(expr.Value))
+func (comp *Compiler) compileIntLitExpr(t *code.Builder, expr ast.ExprIntLit) error {
+	t.AppendOp(code.OpPushInt)
+	t.AppendInt(int64(expr.Value))
 	return nil
 }
 
-func (comp *Compiler) compileFloatLitExpr(expr ast.ExprFloatLit) error {
-	comp.builder.AppendOp(code.OpPushFloat)
-	comp.builder.AppendFloat(expr.Value)
+func (comp *Compiler) compileFloatLitExpr(t *code.Builder, expr ast.ExprFloatLit) error {
+	t.AppendOp(code.OpPushFloat)
+	t.AppendFloat(expr.Value)
 	return nil
 }
 
-func (comp *Compiler) compileStringLitExpr(expr ast.ExprStringLit) error {
-	comp.builder.AppendOp(code.OpPushString)
-	comp.builder.AppendString(expr.Value)
+func (comp *Compiler) compileStringLitExpr(t *code.Builder, expr ast.ExprStringLit) error {
+	t.AppendOp(code.OpPushString)
+	t.AppendString(expr.Value)
 	return nil
 }
 
-func (comp *Compiler) compileBoolLitExpr(expr ast.ExprBoolLit) error {
+func (comp *Compiler) compileBoolLitExpr(t *code.Builder, expr ast.ExprBoolLit) error {
 	if expr.Value {
-		comp.builder.AppendOp(code.OpPushTrue)
+		t.AppendOp(code.OpPushTrue)
 	} else {
-		comp.builder.AppendOp(code.OpPushFalse)
+		t.AppendOp(code.OpPushFalse)
 	}
 	return nil
 }
 
-func (comp *Compiler) compileNullLitExpr(expr ast.ExprNullLit) error {
-	comp.builder.AppendOp(code.OpPushNull)
+func (comp *Compiler) compileNullLitExpr(t *code.Builder, expr ast.ExprNullLit) error {
+	t.AppendOp(code.OpPushNull)
+	return nil
+}
+
+func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) error {
+	for _, stmt := range block.Stmts {
+		if err := comp.compileStmt(t, stmt); err != nil {
+			return err
+		}
+	}
+
 	return nil
 }