about summary refs log tree commit diff
path: root/pkg/lang/compiler/compiler.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/lang/compiler/compiler.go')
-rw-r--r--pkg/lang/compiler/compiler.go73
1 files changed, 45 insertions, 28 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index f4a900a..cad4328 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -67,56 +67,73 @@ func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl)
 	return nil
 }
 
-func (comp *Compiler) compileIfStmt(t *code.Builder, ifstmt ast.StmtIf) error {
-	// push_false -> cond // only on ifs and elifs
-	// jf @elif -> condjmp // only on ifs and elifs
-
-	// push_int 1 -> then
-	// jmp @end -> thenjmp // except on last cond
-
-	subUnits := make([]code.Builder, 0, len(ifstmt.Conds))
-
-	totalLength := t.Len()
-	jmpLength := 9 // OP + Uint
-
-	for i, cond := range ifstmt.Conds {
-		thenTarget := code.NewBuilder() // then
+func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error {
+	// An if statement is composed out of CondNodes.
+	// A cond node can be either the top `if` branch, and `elif` branch
+	// inbetween, or `else` branch at the bottom.
+	// `if` and `elif` are identical, but the Cond expr of an `else` node is an empty ast.Expr.
+	//
+	// Each compiled CondNode consists of 4 parts.
+	// 1. Condition check => Compiled Cond expr, pushes a bool onto the stack
+	//    Example: `push_false` for `if false {}`
+	// 2. Condition jump => If the condition is false, jump to the **next** CondNode, if not last.
+	//    Example: `jf @elif` or `jf @end`
+	// 3. Then block => Anything the user wants to execute.
+	//    Example: `push_int 1` or something
+	// 4. Then jump => Since the condition was true, we have to jump to the end of the CondNode list,
+	//				   preventing other CondNodes from running. This is missing from the last CondNode.
+	//    Example: `jmp @end`
+
+	subUnits := make([]code.Builder, 0, len(ifStmt.Conds))
+
+	totalLength := 0
+	jmpLength := 9 // The length of either of the jump parts: op: 1 + uint: 8 = 9
+
+	for i, cond := range ifStmt.Conds {
+		// Then block
+		thenTarget := code.NewBuilder()
 		if err := comp.compileBlockNode(&thenTarget, cond.Then); err != nil {
 			return err
 		}
 
 		totalLength += thenTarget.Len()
-		if i != len(ifstmt.Conds)-1 {
-			totalLength += jmpLength // thenjmp
+		if i != len(ifStmt.Conds)-1 {
+			totalLength += jmpLength
 		}
 
-		subUnitTarget := code.NewBuilder()
+		// Condition check
+		conditionTarget := code.NewBuilder()
 		if !cond.Cond.IsEmpty() {
-			// cond
-			if err := comp.compileExpr(&subUnitTarget, cond.Cond); err != nil {
+			if err := comp.compileExpr(&conditionTarget, cond.Cond); err != nil {
 				return err
 			}
 
-			totalLength += subUnitTarget.Len() + jmpLength // condjmp
+			totalLength += conditionTarget.Len() + jmpLength // condjmp
 
-			subUnitTarget.AppendOp(code.OpJf)
-			subUnitTarget.AppendInt(int64(totalLength))
+			conditionTarget.AppendOp(code.OpJf)
+			// Condition jump
+			conditionTarget.AppendReferenceToPc(int64(totalLength))
 		}
 
-		subUnitTarget.AppendRaw(thenTarget.Code())
-
-		subUnits = append(subUnits, subUnitTarget)
+		subUnit := conditionTarget
+		subUnit.AppendBuilder(thenTarget)
+		subUnits = append(subUnits, subUnit)
 	}
 
+	result := code.NewBuilder()
+
+	// Then jumps
 	for i, subUnit := range subUnits {
-		if i != len(ifstmt.Conds)-1 {
+		if i != len(ifStmt.Conds)-1 {
 			subUnit.AppendOp(code.OpJmp)
-			subUnit.AppendInt(int64(totalLength))
+			subUnit.AppendReferenceToPc(int64(totalLength))
 		}
 
-		t.AppendRaw(subUnit.Code())
+		result.AppendBuilderWithoutAdjustingReferences(subUnit)
 	}
 
+	t.AppendBuilder(result)
+
 	return nil
 }