about summary refs log tree commit diff
path: root/pkg/lang/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/lang/compiler')
-rw-r--r--pkg/lang/compiler/compiler.go52
-rw-r--r--pkg/lang/compiler/compiler_test.go71
-rw-r--r--pkg/lang/compiler/scope/scope_chain.go24
-rw-r--r--pkg/lang/compiler/scope/scopes.go27
4 files changed, 159 insertions, 15 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index ad37ded..4167088 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -88,9 +88,11 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 		returnStmt := stmt.Value.(ast.StmtReturn)
 		err = comp.compileReturnStmt(t, returnStmt)
 	case ast.StmtKindContinue:
-		panic("continue statements not implemented")
+		continueStmt := stmt.Value.(ast.StmtContinue)
+		err = comp.compileContinueStmt(t, continueStmt)
 	case ast.StmtKindBreak:
-		panic("break statements not implemented")
+		breakStmt := stmt.Value.(ast.StmtBreak)
+		err = comp.compileBreakStmt(t, breakStmt)
 	case ast.StmtKindThrow:
 		panic("throw statements not implemented")
 	case ast.StmtKindExpr:
@@ -234,12 +236,9 @@ func (comp *Compiler) compileForCondStmt(t *code.Builder, forCondStmt ast.StmtFo
 	// 3. Do block: Does something
 	// 4. Repeat jump: Jumps back to start
 
-	parentMarker := comp.scopes.CreateAnonymousFunctionSubUnit()
-
-	startMarker := parentMarker.SubMarker("start")
-	endMarker := parentMarker.SubMarker("end")
+	endMarker, repeatMarker := comp.scopes.EnterLoop()
 
-	t.PutMarker(startMarker)
+	t.PutMarker(repeatMarker)
 
 	if !forCondStmt.Cond.IsEmpty() {
 		// Condition check
@@ -259,10 +258,12 @@ func (comp *Compiler) compileForCondStmt(t *code.Builder, forCondStmt ast.StmtFo
 
 	// Repeat jump
 	t.AppendOp(code.OpJmp)
-	t.AppendMarkerReference(startMarker)
+	t.AppendMarkerReference(repeatMarker)
 
 	t.PutMarker(endMarker)
 
+	comp.scopes.Exit()
+
 	return nil
 }
 
@@ -301,10 +302,7 @@ func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn)
 	// @end:
 	// halt
 
-	parentMarker := comp.scopes.CreateAnonymousFunctionSubUnit()
-
-	checkMarker := parentMarker.SubUnit("check")
-	endMarker := parentMarker.SubUnit("end")
+	endMarker, repeatMarker := comp.scopes.EnterLoop()
 
 	// Preparation
 	if err := comp.compileExpr(t, forInStmt.Collection); err != nil {
@@ -323,7 +321,7 @@ func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn)
 	}
 
 	// Condition check
-	t.PutMarker(checkMarker)
+	t.PutMarker(repeatMarker)
 
 	t.AppendOp(code.OpGetLocal)
 	t.AppendInt(int64(iLocal))
@@ -373,10 +371,12 @@ func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn)
 
 	// Repeat jump
 	t.AppendOp(code.OpJmp)
-	t.AppendMarkerReference(checkMarker)
+	t.AppendMarkerReference(repeatMarker)
 
 	t.PutMarker(endMarker)
 
+	comp.scopes.Exit()
+
 	return nil
 }
 
@@ -400,6 +400,30 @@ func (comp *Compiler) compileReturnStmt(t *code.Builder, returnStmt ast.StmtRetu
 	return nil
 }
 
+func (comp *Compiler) compileContinueStmt(t *code.Builder, continueStmt ast.StmtContinue) error {
+	loopScope := comp.scopes.CurrentLoop()
+	if loopScope == nil {
+		return fmt.Errorf("can't continue when not inside a loop")
+	}
+
+	t.AppendOp(code.OpJmp)
+	t.AppendMarkerReference(loopScope.ContinueMarker())
+
+	return nil
+}
+
+func (comp *Compiler) compileBreakStmt(t *code.Builder, breakStmt ast.StmtBreak) error {
+	loopScope := comp.scopes.CurrentLoop()
+	if loopScope == nil {
+		return fmt.Errorf("can't break when not inside a loop")
+	}
+
+	t.AppendOp(code.OpJmp)
+	t.AppendMarkerReference(loopScope.BreakMarker())
+
+	return nil
+}
+
 func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 	switch expr.Kind {
 	case ast.ExprKindBinary:
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index bdbc375..7741ca9 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -291,9 +291,65 @@ func TestForCond(t *testing.T) {
 	mustCompileTo(t, src, expected)
 }
 
+func TestForCondBreakContinue(t *testing.T) {
+	src := `
+	var sum = 0
+	var x = 0
+	for {
+		if x % 2 == 0 {
+			continue
+		}
+
+		sum = sum + x
+
+		if x == 100 {
+			break
+		}
+	}
+	`
+
+	expected := `
+	push_int 0
+	push_int 0
+
+	@continue:
+	get_local 1
+	push_int 2
+	mod
+	push_int 0
+	eq
+	jf @sum
+	jmp @continue
+
+	@sum:
+	get_local 0
+	get_local 1
+	add
+	set_local 0
+
+	get_local 1
+	push_int 100
+	eq
+	jf @repeat
+	jmp @break
+
+	@repeat:
+	jmp @continue
+
+	@break:
+	halt
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
 func TestForIn(t *testing.T) {
 	src := `
-	for x in [1, 2, 3] {
+	for x in [1, 2, 3, "oops"] {
+		if x == "oops" {
+			break
+		}
+
 		"say"(x)
 	}
 	`
@@ -316,6 +372,11 @@ func TestForIn(t *testing.T) {
 	push_int 3
 	call 1
 
+	get_local 0
+	get_member "push"
+	push_string "oops"
+	call 1
+
 	push_int 0
 	push_null
 
@@ -338,6 +399,14 @@ func TestForIn(t *testing.T) {
 	add
 	set_local 1
 
+	get_local 2
+	push_string "oops"
+	eq
+	jf @say
+
+	jmp @end
+
+	@say:
 	push_string "say"
 	get_local 2
 	call 1
diff --git a/pkg/lang/compiler/scope/scope_chain.go b/pkg/lang/compiler/scope/scope_chain.go
index 683d53a..4d5bbcf 100644
--- a/pkg/lang/compiler/scope/scope_chain.go
+++ b/pkg/lang/compiler/scope/scope_chain.go
@@ -10,6 +10,7 @@ type ScopeChain struct {
 
 	symbolScopes   []SymbolScope // All other scopes are bound to this by ID.
 	functionScopes []FunctionScope
+	loopScopes     []LoopScope
 }
 
 func NewScopeChain() ScopeChain {
@@ -19,11 +20,14 @@ func NewScopeChain() ScopeChain {
 	functionScopes := make([]FunctionScope, 1)
 	functionScopes[0] = NewFunctionScope(0, "") // Root function to house top-scope sub units
 
+	loopScopes := make([]LoopScope, 0)
+
 	return ScopeChain{
 		nameToSymbol: make(map[string]SymbolID),
 
 		symbolScopes:   symbolScopes,
 		functionScopes: functionScopes,
+		loopScopes:     loopScopes,
 	}
 }
 
@@ -51,6 +55,14 @@ func (sc *ScopeChain) CurrentFunction() *FunctionScope {
 	return &sc.functionScopes[len(sc.functionScopes)-1]
 }
 
+func (sc *ScopeChain) CurrentLoop() *LoopScope {
+	if len(sc.loopScopes) == 0 {
+		return nil
+	}
+
+	return &sc.loopScopes[len(sc.loopScopes)-1]
+}
+
 func (sc *ScopeChain) Enter() {
 	sc.symbolScopes = append(sc.symbolScopes, NewSymbolScope())
 }
@@ -60,6 +72,18 @@ func (sc *ScopeChain) EnterFunction(unit code.Marker) {
 	sc.functionScopes = append(sc.functionScopes, NewFunctionScope(sc.CurrentScopeID(), unit))
 }
 
+func (sc *ScopeChain) EnterLoop() (code.Marker, code.Marker) {
+	parentMarker := sc.CreateAnonymousFunctionSubUnit()
+
+	breakMarker := parentMarker.SubMarker("end")
+	continueMarker := parentMarker.SubMarker("start")
+
+	sc.Enter()
+	sc.loopScopes = append(sc.loopScopes, NewLoopScope(sc.CurrentScopeID(), breakMarker, continueMarker))
+
+	return breakMarker, continueMarker
+}
+
 func (sc *ScopeChain) Exit() {
 	if sc.CurrentScopeID() == 0 {
 		return
diff --git a/pkg/lang/compiler/scope/scopes.go b/pkg/lang/compiler/scope/scopes.go
index 39e48ef..5cfcd5d 100644
--- a/pkg/lang/compiler/scope/scopes.go
+++ b/pkg/lang/compiler/scope/scopes.go
@@ -9,6 +9,7 @@ type ScopeKind int
 const (
 	ScopeKindNormal ScopeKind = iota
 	ScopeKindFunction
+	ScopeKindLoop
 )
 
 type SymbolScope struct {
@@ -48,3 +49,29 @@ func (sf FunctionScope) Unit() code.Marker {
 func (sf FunctionScope) IsRootScope() bool {
 	return sf.ID() == ScopeID(0)
 }
+
+type LoopScope struct {
+	id             ScopeID
+	breakMarker    code.Marker
+	continueMarker code.Marker
+}
+
+func NewLoopScope(id ScopeID, breakMarker code.Marker, continueMarker code.Marker) LoopScope {
+	return LoopScope{
+		id:             id,
+		breakMarker:    breakMarker,
+		continueMarker: continueMarker,
+	}
+}
+
+func (sl LoopScope) ID() ScopeID {
+	return sl.id
+}
+
+func (sl LoopScope) BreakMarker() code.Marker {
+	return sl.breakMarker
+}
+
+func (sl LoopScope) ContinueMarker() code.Marker {
+	return sl.continueMarker
+}