about summary refs log tree commit diff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/lang/compiler/compiler.go73
-rw-r--r--pkg/lang/compiler/compiler_test.go31
-rw-r--r--pkg/lang/vm/code/builder.go77
-rw-r--r--pkg/libs/rangemap/rangemap.go3
4 files changed, 146 insertions, 38 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index f4a900a..cad4328 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -67,56 +67,73 @@ func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl)
 	return nil
 }
 
-func (comp *Compiler) compileIfStmt(t *code.Builder, ifstmt ast.StmtIf) error {
-	// push_false -> cond // only on ifs and elifs
-	// jf @elif -> condjmp // only on ifs and elifs
-
-	// push_int 1 -> then
-	// jmp @end -> thenjmp // except on last cond
-
-	subUnits := make([]code.Builder, 0, len(ifstmt.Conds))
-
-	totalLength := t.Len()
-	jmpLength := 9 // OP + Uint
-
-	for i, cond := range ifstmt.Conds {
-		thenTarget := code.NewBuilder() // then
+func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error {
+	// An if statement is composed out of CondNodes.
+	// A cond node can be either the top `if` branch, and `elif` branch
+	// inbetween, or `else` branch at the bottom.
+	// `if` and `elif` are identical, but the Cond expr of an `else` node is an empty ast.Expr.
+	//
+	// Each compiled CondNode consists of 4 parts.
+	// 1. Condition check => Compiled Cond expr, pushes a bool onto the stack
+	//    Example: `push_false` for `if false {}`
+	// 2. Condition jump => If the condition is false, jump to the **next** CondNode, if not last.
+	//    Example: `jf @elif` or `jf @end`
+	// 3. Then block => Anything the user wants to execute.
+	//    Example: `push_int 1` or something
+	// 4. Then jump => Since the condition was true, we have to jump to the end of the CondNode list,
+	//				   preventing other CondNodes from running. This is missing from the last CondNode.
+	//    Example: `jmp @end`
+
+	subUnits := make([]code.Builder, 0, len(ifStmt.Conds))
+
+	totalLength := 0
+	jmpLength := 9 // The length of either of the jump parts: op: 1 + uint: 8 = 9
+
+	for i, cond := range ifStmt.Conds {
+		// Then block
+		thenTarget := code.NewBuilder()
 		if err := comp.compileBlockNode(&thenTarget, cond.Then); err != nil {
 			return err
 		}
 
 		totalLength += thenTarget.Len()
-		if i != len(ifstmt.Conds)-1 {
-			totalLength += jmpLength // thenjmp
+		if i != len(ifStmt.Conds)-1 {
+			totalLength += jmpLength
 		}
 
-		subUnitTarget := code.NewBuilder()
+		// Condition check
+		conditionTarget := code.NewBuilder()
 		if !cond.Cond.IsEmpty() {
-			// cond
-			if err := comp.compileExpr(&subUnitTarget, cond.Cond); err != nil {
+			if err := comp.compileExpr(&conditionTarget, cond.Cond); err != nil {
 				return err
 			}
 
-			totalLength += subUnitTarget.Len() + jmpLength // condjmp
+			totalLength += conditionTarget.Len() + jmpLength // condjmp
 
-			subUnitTarget.AppendOp(code.OpJf)
-			subUnitTarget.AppendInt(int64(totalLength))
+			conditionTarget.AppendOp(code.OpJf)
+			// Condition jump
+			conditionTarget.AppendReferenceToPc(int64(totalLength))
 		}
 
-		subUnitTarget.AppendRaw(thenTarget.Code())
-
-		subUnits = append(subUnits, subUnitTarget)
+		subUnit := conditionTarget
+		subUnit.AppendBuilder(thenTarget)
+		subUnits = append(subUnits, subUnit)
 	}
 
+	result := code.NewBuilder()
+
+	// Then jumps
 	for i, subUnit := range subUnits {
-		if i != len(ifstmt.Conds)-1 {
+		if i != len(ifStmt.Conds)-1 {
 			subUnit.AppendOp(code.OpJmp)
-			subUnit.AppendInt(int64(totalLength))
+			subUnit.AppendReferenceToPc(int64(totalLength))
 		}
 
-		t.AppendRaw(subUnit.Code())
+		result.AppendBuilderWithoutAdjustingReferences(subUnit)
 	}
 
+	t.AppendBuilder(result)
+
 	return nil
 }
 
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index a6a1ba7..88413c7 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -203,6 +203,37 @@ func TestIfNoElse(t *testing.T) {
 	mustCompileTo(t, src, expected)
 }
 
+func TestNestedIfs(t *testing.T) {
+	src := `
+	if true {
+		if false {
+			1
+		} else {
+			2
+		}
+	}
+	`
+
+	expected := `
+	push_true
+	jf @end
+
+	push_false
+	jf @else
+
+	push_int 1
+	jmp @end
+
+	@else:
+	push_int 2
+	
+	@end:
+	halt
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
 func mustCompileTo(t *testing.T, src, expected string) {
 	scanner := scanner.New(strings.NewReader(src))
 	tokens, err := scanner.Scan()
diff --git a/pkg/lang/vm/code/builder.go b/pkg/lang/vm/code/builder.go
index 4a5ef70..f027c49 100644
--- a/pkg/lang/vm/code/builder.go
+++ b/pkg/lang/vm/code/builder.go
@@ -2,12 +2,14 @@ package code
 
 import (
 	"encoding/binary"
+	"fmt"
 	"math"
 )
 
 type Builder struct {
-	code      []byte
-	debugInfo DebugInfo
+	code           []byte
+	debugInfo      DebugInfo
+	relativePcRefs map[int][]int
 
 	currentLine int
 	lineStart   int
@@ -15,10 +17,11 @@ type Builder struct {
 
 func NewBuilder() Builder {
 	return Builder{
-		code:        make([]byte, 0, 64),
-		debugInfo:   NewDebugInfo("unknown file"),
-		lineStart:   -1,
-		currentLine: -1,
+		code:           make([]byte, 0, 64),
+		debugInfo:      NewDebugInfo("unknown file"),
+		relativePcRefs: make(map[int][]int),
+		lineStart:      -1,
+		currentLine:    -1,
 	}
 }
 
@@ -40,17 +43,59 @@ func (b *Builder) AppendString(s string) {
 	b.code = append(b.code, 0)
 }
 
+func (b *Builder) AppendReferenceToPc(pc int64) {
+	b.addPcRef(int(pc), b.Len())
+	b.AppendInt(pc)
+}
+
 func (b *Builder) AppendRaw(raw Raw) {
 	b.code = append(b.code, raw...)
 }
 
+func (b *Builder) AppendBuilder(other Builder) {
+	code := other.code
+	for pc, refsToPc := range other.relativePcRefs {
+		newPc := b.Len() + pc
+		for _, ref := range refsToPc {
+			valueAtRef := binary.LittleEndian.Uint64(code[ref : ref+8])
+			if int(valueAtRef) != pc {
+				panic(fmt.Errorf("reference to pc in builder does not actually reference pc. (pc: %d, value at reference: %d)", pc, valueAtRef))
+			}
+
+			binary.LittleEndian.PutUint64(code[ref:], uint64(newPc))
+			b.addPcRef(newPc, ref+b.Len())
+		}
+	}
+
+	b.code = append(b.code, code...)
+
+	if other.debugInfo.pcToLine.Len() != 0 || other.currentLine != -1 || other.lineStart != -1 {
+		panic("appending debug infos not implemented yet")
+	}
+}
+
+func (b *Builder) AppendBuilderWithoutAdjustingReferences(other Builder) {
+	code := other.code
+	for pc, refsToPc := range other.relativePcRefs {
+		for _, ref := range refsToPc {
+			b.addPcRef(pc, ref+b.Len())
+		}
+	}
+
+	b.code = append(b.code, code...)
+
+	if other.debugInfo.pcToLine.Len() != 0 || other.currentLine != -1 || other.lineStart != -1 {
+		panic("appending debug infos not implemented yet")
+	}
+}
+
 func (b *Builder) StartLine(line int) {
 	if b.lineStart != -1 {
 		panic("line already started")
 	}
 
 	b.currentLine = line
-	b.lineStart = len(b.code)
+	b.lineStart = b.Len()
 }
 
 func (b *Builder) EndLine() {
@@ -63,11 +108,11 @@ func (b *Builder) EndLine() {
 		panic("line not started")
 	}
 
-	if b.lineStart == len(b.code) {
+	if b.lineStart == b.Len() {
 		return
 	}
-	
-	b.debugInfo.AppendLine(len(b.code)-1, b.currentLine)
+
+	b.debugInfo.AppendLine(b.Len()-1, b.currentLine)
 }
 
 func (b *Builder) SetInt(at int, x int64) {
@@ -85,3 +130,15 @@ func (b *Builder) Len() int {
 func (b *Builder) Build() Code {
 	return New(b.code, b.debugInfo)
 }
+
+func (b *Builder) addPcRef(pc int, at int) {
+	refs, ok := b.relativePcRefs[pc]
+	if ok {
+		refs = append(refs, at)
+	} else {
+		refs = make([]int, 1)
+		refs[0] = at
+	}
+
+	b.relativePcRefs[pc] = refs
+}
diff --git a/pkg/libs/rangemap/rangemap.go b/pkg/libs/rangemap/rangemap.go
index df891a6..39f97ea 100644
--- a/pkg/libs/rangemap/rangemap.go
+++ b/pkg/libs/rangemap/rangemap.go
@@ -51,6 +51,9 @@ func (rm *RangeMap[D]) Get(i int) (*D, bool) {
 
 	return nil, false
 }
+func (rm *RangeMap[D]) Len() int {
+	return len(rm.ranges)
+}
 
 type rangeEntry struct {
 	from int