diff options
Diffstat (limited to 'pkg/lang/compiler')
| -rw-r--r-- | pkg/lang/compiler/compiler.go | 73 | ||||
| -rw-r--r-- | pkg/lang/compiler/compiler_test.go | 31 |
2 files changed, 76 insertions, 28 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go index f4a900a..cad4328 100644 --- a/pkg/lang/compiler/compiler.go +++ b/pkg/lang/compiler/compiler.go @@ -67,56 +67,73 @@ func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) return nil } -func (comp *Compiler) compileIfStmt(t *code.Builder, ifstmt ast.StmtIf) error { - // push_false -> cond // only on ifs and elifs - // jf @elif -> condjmp // only on ifs and elifs - - // push_int 1 -> then - // jmp @end -> thenjmp // except on last cond - - subUnits := make([]code.Builder, 0, len(ifstmt.Conds)) - - totalLength := t.Len() - jmpLength := 9 // OP + Uint - - for i, cond := range ifstmt.Conds { - thenTarget := code.NewBuilder() // then +func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error { + // An if statement is composed out of CondNodes. + // A cond node can be either the top `if` branch, and `elif` branch + // inbetween, or `else` branch at the bottom. + // `if` and `elif` are identical, but the Cond expr of an `else` node is an empty ast.Expr. + // + // Each compiled CondNode consists of 4 parts. + // 1. Condition check => Compiled Cond expr, pushes a bool onto the stack + // Example: `push_false` for `if false {}` + // 2. Condition jump => If the condition is false, jump to the **next** CondNode, if not last. + // Example: `jf @elif` or `jf @end` + // 3. Then block => Anything the user wants to execute. + // Example: `push_int 1` or something + // 4. Then jump => Since the condition was true, we have to jump to the end of the CondNode list, + // preventing other CondNodes from running. This is missing from the last CondNode. + // Example: `jmp @end` + + subUnits := make([]code.Builder, 0, len(ifStmt.Conds)) + + totalLength := 0 + jmpLength := 9 // The length of either of the jump parts: op: 1 + uint: 8 = 9 + + for i, cond := range ifStmt.Conds { + // Then block + thenTarget := code.NewBuilder() if err := comp.compileBlockNode(&thenTarget, cond.Then); err != nil { return err } totalLength += thenTarget.Len() - if i != len(ifstmt.Conds)-1 { - totalLength += jmpLength // thenjmp + if i != len(ifStmt.Conds)-1 { + totalLength += jmpLength } - subUnitTarget := code.NewBuilder() + // Condition check + conditionTarget := code.NewBuilder() if !cond.Cond.IsEmpty() { - // cond - if err := comp.compileExpr(&subUnitTarget, cond.Cond); err != nil { + if err := comp.compileExpr(&conditionTarget, cond.Cond); err != nil { return err } - totalLength += subUnitTarget.Len() + jmpLength // condjmp + totalLength += conditionTarget.Len() + jmpLength // condjmp - subUnitTarget.AppendOp(code.OpJf) - subUnitTarget.AppendInt(int64(totalLength)) + conditionTarget.AppendOp(code.OpJf) + // Condition jump + conditionTarget.AppendReferenceToPc(int64(totalLength)) } - subUnitTarget.AppendRaw(thenTarget.Code()) - - subUnits = append(subUnits, subUnitTarget) + subUnit := conditionTarget + subUnit.AppendBuilder(thenTarget) + subUnits = append(subUnits, subUnit) } + result := code.NewBuilder() + + // Then jumps for i, subUnit := range subUnits { - if i != len(ifstmt.Conds)-1 { + if i != len(ifStmt.Conds)-1 { subUnit.AppendOp(code.OpJmp) - subUnit.AppendInt(int64(totalLength)) + subUnit.AppendReferenceToPc(int64(totalLength)) } - t.AppendRaw(subUnit.Code()) + result.AppendBuilderWithoutAdjustingReferences(subUnit) } + t.AppendBuilder(result) + return nil } diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go index a6a1ba7..88413c7 100644 --- a/pkg/lang/compiler/compiler_test.go +++ b/pkg/lang/compiler/compiler_test.go @@ -203,6 +203,37 @@ func TestIfNoElse(t *testing.T) { mustCompileTo(t, src, expected) } +func TestNestedIfs(t *testing.T) { + src := ` + if true { + if false { + 1 + } else { + 2 + } + } + ` + + expected := ` + push_true + jf @end + + push_false + jf @else + + push_int 1 + jmp @end + + @else: + push_int 2 + + @end: + halt + ` + + mustCompileTo(t, src, expected) +} + func mustCompileTo(t *testing.T, src, expected string) { scanner := scanner.New(strings.NewReader(src)) tokens, err := scanner.Scan() |
