about summary refs log tree commit diff
path: root/pkg/lang
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-07-19 02:27:19 +0200
committerMel <einebeere@gmail.com>2022-07-19 02:27:19 +0200
commitb6fa4bc82398b09307f2e6b75e27422d1d1ecb33 (patch)
treee5b7aec7eb7f72f1c2f55e4b2a78d331bd81485e /pkg/lang
parente06aeb7fa2fcb9046b8861ed3c23417555e823f5 (diff)
downloadjinx-b6fa4bc82398b09307f2e6b75e27422d1d1ecb33.tar.zst
jinx-b6fa4bc82398b09307f2e6b75e27422d1d1ecb33.zip
Implement stack hygiene
Diffstat (limited to 'pkg/lang')
-rw-r--r--pkg/lang/compiler/compiler.go53
-rw-r--r--pkg/lang/compiler/compiler_test.go20
-rw-r--r--pkg/lang/compiler/scope/scope_chain.go24
-rw-r--r--pkg/lang/vm/exec.go10
-rw-r--r--pkg/lang/vm/text/decompiler.go2
-rw-r--r--pkg/lang/vm/vm.go5
-rw-r--r--pkg/lang/vm/vm_test.go5
7 files changed, 102 insertions, 17 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index 4167088..e446dd6 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -96,8 +96,8 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	case ast.StmtKindThrow:
 		panic("throw statements not implemented")
 	case ast.StmtKindExpr:
-		expr := stmt.Value.(ast.StmtExpr).Value
-		err = comp.compileExpr(t, expr)
+		exprStmt := stmt.Value.(ast.StmtExpr)
+		err = comp.compileExprStmt(t, exprStmt)
 	default:
 		panic(fmt.Errorf("unknown statement kind: %d", stmt.Kind))
 	}
@@ -139,7 +139,8 @@ func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDe
 		return err
 	}
 
-	comp.scopes.Exit()
+	// Function declaration scopes do not pollute stack
+	_ = comp.scopes.Exit()
 
 	comp.funcs = append(comp.funcs, &functionTarget)
 
@@ -175,6 +176,8 @@ func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error {
 	//				   preventing other CondNodes from running. This is missing from the last CondNode.
 	//    Example: `jmp @end`
 
+	comp.scopes.Enter()
+
 	// First we create all the markers we'll need for the if statement
 	parentMarker := comp.scopes.CreateAnonymousFunctionSubUnit()
 
@@ -226,6 +229,8 @@ func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error {
 
 	t.PutMarker(endMarker)
 
+	comp.exitScopeAndCleanStack(t)
+
 	return nil
 }
 
@@ -251,18 +256,24 @@ func (comp *Compiler) compileForCondStmt(t *code.Builder, forCondStmt ast.StmtFo
 		t.AppendMarkerReference(endMarker)
 	}
 
+	// Inner scope, dropped on every iteration
+	comp.scopes.Enter()
+
 	// Do block
 	if err := comp.compileBlockNode(t, forCondStmt.Do); err != nil {
 		return err
 	}
 
+	// Drop inner scope
+	comp.exitScopeAndCleanStack(t)
+
 	// Repeat jump
 	t.AppendOp(code.OpJmp)
 	t.AppendMarkerReference(repeatMarker)
 
 	t.PutMarker(endMarker)
 
-	comp.scopes.Exit()
+	comp.exitScopeAndCleanStack(t)
 
 	return nil
 }
@@ -302,6 +313,7 @@ func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn)
 	// @end:
 	// halt
 
+	// Upper scope houses all internal loop locals, which are dropped when the loop ends.
 	endMarker, repeatMarker := comp.scopes.EnterLoop()
 
 	// Preparation
@@ -364,18 +376,25 @@ func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn)
 	t.AppendOp(code.OpSetLocal)
 	t.AppendInt(int64(iLocal))
 
+	// Inner scope, dropped every loop iteration.
+	comp.scopes.Enter()
+
 	// Do block
 	if err := comp.compileBlockNode(t, forInStmt.Do); err != nil {
 		return err
 	}
 
+	// Drop inner scope
+	comp.exitScopeAndCleanStack(t)
+
 	// Repeat jump
 	t.AppendOp(code.OpJmp)
 	t.AppendMarkerReference(repeatMarker)
 
 	t.PutMarker(endMarker)
 
-	comp.scopes.Exit()
+	// Drop upper scope
+	comp.exitScopeAndCleanStack(t)
 
 	return nil
 }
@@ -424,6 +443,23 @@ func (comp *Compiler) compileBreakStmt(t *code.Builder, breakStmt ast.StmtBreak)
 	return nil
 }
 
+func (comp *Compiler) compileExprStmt(t *code.Builder, exprStmt ast.StmtExpr) error {
+	if err := comp.compileExpr(t, exprStmt.Value); err != nil {
+		return err
+	}
+
+	isAssignment := exprStmt.Value.Kind == ast.ExprKindBinary &&
+		exprStmt.Value.Value.(ast.ExprBinary).Op == ast.BinOpAssign
+
+	// If the expression is not assignment, we need to drop the junk value.
+	if !isAssignment {
+		t.AppendOp(code.OpDrop)
+		t.AppendInt(1)
+	}
+
+	return nil
+}
+
 func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 	switch expr.Kind {
 	case ast.ExprKindBinary:
@@ -676,3 +712,10 @@ func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) err
 
 	return nil
 }
+
+func (comp *Compiler) exitScopeAndCleanStack(t *code.Builder) {
+	if stackSpace := comp.scopes.Exit(); stackSpace != 0 {
+		t.AppendOp(code.OpDrop)
+		t.AppendInt(int64(stackSpace))
+	}
+}
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index 7741ca9..b8a6264 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -21,6 +21,7 @@ func TestSimpleAddExpr(t *testing.T) {
 	push_int 1
 	push_int 2
 	add
+	drop 1
 	halt
 	`
 
@@ -44,6 +45,7 @@ func TestOperationOrder(t *testing.T) {
 	sub
 	push_int 4
 	add
+	drop 1
 	halt
 	`
 
@@ -78,6 +80,8 @@ func TestNestedExpr(t *testing.T) {
 	add
 	sub
 
+	drop 1
+
 	halt
 	`
 
@@ -141,15 +145,12 @@ func TestArrayLit(t *testing.T) {
 
 func TestIf(t *testing.T) {
 	src := `
-	":("
 	if 1 <= 5 {
 		"hello " + "world"
 	}
 	`
 
 	expected := `
-	push_string ":("
-
 	push_int 1
 	push_int 5
 	lte
@@ -160,6 +161,8 @@ func TestIf(t *testing.T) {
 	push_string "world"
 	add
 
+	drop 1
+
 	@end:
 	halt
 	`
@@ -183,16 +186,19 @@ func TestIfElifElse(t *testing.T) {
 	jf @elif
 
 	push_int 1
+	drop 1
 	jmp @end
 
 	@elif:
 	push_true
 	jf @else
 	push_int 2
+	drop 1
 	jmp @end
 
 	@else:
 	push_int 3
+	drop 1
 
 	@end:
 	halt
@@ -215,12 +221,14 @@ func TestIfNoElse(t *testing.T) {
 	jf @elif
 
 	push_int 1
+	drop 1
 	jmp @end
 
 	@elif:
 	push_true
 	jf @end
 	push_int 2
+	drop 1
 
 	@end:
 	halt
@@ -248,10 +256,12 @@ func TestNestedIfs(t *testing.T) {
 	jf @else
 
 	push_int 1
+	drop 1
 	jmp @end
 
 	@else:
 	push_int 2
+	drop 1
 	
 	@end:
 	halt
@@ -279,6 +289,7 @@ func TestForCond(t *testing.T) {
 	jf @inner_end
 
 	push_int 1
+	drop 1
 	jmp @inner_start
 
 	@inner_end:
@@ -410,9 +421,11 @@ func TestForIn(t *testing.T) {
 	push_string "say"
 	get_local 2
 	call 1
+	drop 1
 
 	jmp @check
 	@end:
+	drop 3
 	halt
 	`
 
@@ -456,6 +469,7 @@ func TestFunctionArgs(t *testing.T) {
 	push_int 4
 	push_int 5
 	call 2
+	drop 1
 	halt
 
 	@add:
diff --git a/pkg/lang/compiler/scope/scope_chain.go b/pkg/lang/compiler/scope/scope_chain.go
index 4d5bbcf..d0108ee 100644
--- a/pkg/lang/compiler/scope/scope_chain.go
+++ b/pkg/lang/compiler/scope/scope_chain.go
@@ -84,13 +84,29 @@ func (sc *ScopeChain) EnterLoop() (code.Marker, code.Marker) {
 	return breakMarker, continueMarker
 }
 
-func (sc *ScopeChain) Exit() {
+func (sc *ScopeChain) Exit() int {
 	if sc.CurrentScopeID() == 0 {
-		return
+		return 0
 	}
+	id := sc.CurrentScopeID()
+	stackSpace := len(sc.symbolScopes[id].variableSymbols)
 
-	sc.symbolScopes[sc.CurrentScopeID()] = SymbolScope{}
-	sc.symbolScopes = sc.symbolScopes[:sc.CurrentScopeID()]
+	sc.symbolScopes[id] = SymbolScope{}
+	sc.symbolScopes = sc.symbolScopes[:id]
+
+	if sc.CurrentLoop() != nil && sc.CurrentLoop().id == id {
+		lastLoopScope := len(sc.loopScopes) - 1
+		sc.loopScopes[lastLoopScope] = LoopScope{}
+		sc.loopScopes = sc.loopScopes[:lastLoopScope]
+	}
+
+	if sc.CurrentFunction().id == id {
+		lastFunctionScope := len(sc.functionScopes) - 1
+		sc.functionScopes[lastFunctionScope] = FunctionScope{}
+		sc.functionScopes = sc.functionScopes[:lastFunctionScope]
+	}
+
+	return stackSpace
 }
 
 func (sc *ScopeChain) Declare(name string) (int, bool) {
diff --git a/pkg/lang/vm/exec.go b/pkg/lang/vm/exec.go
index f92e486..3a1ce36 100644
--- a/pkg/lang/vm/exec.go
+++ b/pkg/lang/vm/exec.go
@@ -97,6 +97,16 @@ func (vm *VM) execPushType(name string) error {
 	return nil
 }
 
+func (vm *VM) execDrop(dropAmount uint) error {
+	for i := 0; i < int(dropAmount); i++ {
+		if _, err := vm.popAndDrop(); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
 func (vm *VM) execGetMember(name string) error {
 	parent, err := vm.stack.Pop()
 	if err != nil {
diff --git a/pkg/lang/vm/text/decompiler.go b/pkg/lang/vm/text/decompiler.go
index b3c8bb4..2b4e704 100644
--- a/pkg/lang/vm/text/decompiler.go
+++ b/pkg/lang/vm/text/decompiler.go
@@ -55,7 +55,6 @@ func (d *Decompiler) decompileInstruction(bc code.Raw) (string, code.Raw) {
 		code.OpPushNull,
 		code.OpPushArray,
 		code.OpPushObject,
-		code.OpDrop,
 		code.OpAnchorType,
 		code.OpAdd,
 		code.OpSub,
@@ -75,6 +74,7 @@ func (d *Decompiler) decompileInstruction(bc code.Raw) (string, code.Raw) {
 
 	// Operations that take an int.
 	case code.OpPushInt,
+		code.OpDrop,
 		code.OpGetLocal,
 		code.OpSetLocal,
 		code.OpGetEnv,
diff --git a/pkg/lang/vm/vm.go b/pkg/lang/vm/vm.go
index ff9c28e..8b47915 100644
--- a/pkg/lang/vm/vm.go
+++ b/pkg/lang/vm/vm.go
@@ -116,7 +116,10 @@ func (vm *VM) step(op code.Op) (stepDecision, error) {
 		err = vm.execPushType(name)
 
 	case code.OpDrop:
-		_, err = vm.stack.Pop()
+		dropAmount, advance := vm.code.GetUint(vm.pc)
+		vm.pc += advance
+
+		err = vm.execDrop(uint(dropAmount))
 
 	case code.OpGetGlobal:
 		panic("not implemented")
diff --git a/pkg/lang/vm/vm_test.go b/pkg/lang/vm/vm_test.go
index 7fe07dd..7e749a4 100644
--- a/pkg/lang/vm/vm_test.go
+++ b/pkg/lang/vm/vm_test.go
@@ -75,7 +75,7 @@ func TestFibonacci(t *testing.T) {
 	temp_arr_push
 
 	# Drop local 1, which was the length of the array, which we no longer need
-	drop
+	drop 1
 
 	get_local 0
 	temp_arr_len
@@ -391,8 +391,7 @@ func TestPrimes(t *testing.T) {
 		add
 		set_local 0
 
-		drop
-		drop
+		drop 2
 
 		jmp @main_loop
 	@end: