about summary refs log tree commit diff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/lang/compiler/compiler.go41
-rw-r--r--pkg/lang/compiler/compiler_test.go37
-rw-r--r--pkg/lang/compiler/scope/scope_chain.go60
-rw-r--r--pkg/lang/compiler/scope/scopes.go8
-rw-r--r--pkg/lang/compiler/scope/symbol.go23
5 files changed, 147 insertions, 22 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index 0424862..5e2a645 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -112,8 +112,8 @@ func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDe
 		return err
 	}
 
-	// Function declaration scopes do not pollute stack
-	_ = comp.scopes.Exit()
+	fnScope := comp.scopes.CurrentFunction()
+	_ = comp.scopes.Exit() // Function declaration scopes do not pollute stack
 
 	comp.funcs = append(comp.funcs, &functionTarget)
 
@@ -127,7 +127,15 @@ func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDe
 		t.AppendInt(int64(len(fnDeclStmt.Args)))
 	}
 
-	// TODO: Attach the function environment
+	for _, outsideSymbol := range fnScope.OutsideSymbolsInEnv() {
+		// TODO: Implement env inheritance
+		if comp.scopes.CurrentFunction().ID() != outsideSymbol.ScopeID() {
+			panic("env inheritance not implemented")
+		}
+
+		t.AppendOp(code.OpAddToEnv)
+		t.AppendInt(int64(outsideSymbol.IndexInScope()))
+	}
 
 	return nil
 }
@@ -537,18 +545,24 @@ func (comp *Compiler) compileAssignExpr(t *code.Builder, expr ast.ExprBinary) er
 		return fmt.Errorf("variable %s not declared", name)
 	}
 
-	if symbolId.SymbolKind() != scope.SymbolKindVariable {
-		return fmt.Errorf("can't assign to a %v", symbolId.SymbolKind())
-	}
-
-	symbol := comp.scopes.GetVariable(symbolId)
-
 	if err := comp.compileExpr(t, expr.Right); err != nil {
 		return err
 	}
 
-	t.AppendOp(code.OpSetLocal)
-	t.AppendInt(int64(symbol.Data().LocalIndex()))
+	switch symbolId.SymbolKind() {
+	case scope.SymbolKindVariable:
+		symbol := comp.scopes.GetVariable(symbolId)
+
+		t.AppendOp(code.OpSetLocal)
+		t.AppendInt(int64(symbol.Data().LocalIndex()))
+	case scope.SymbolKindEnv:
+		symbol := comp.scopes.GetEnv(symbolId)
+
+		t.AppendOp(code.OpSetEnv)
+		t.AppendInt(int64(symbol.Data().IndexInEnv()))
+	default:
+		panic("unknown symbol kind")
+	}
 
 	return nil
 }
@@ -639,6 +653,11 @@ func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) erro
 
 		t.AppendOp(code.OpGetLocal)
 		t.AppendInt(int64(symbol.Data().LocalIndex()))
+	case scope.SymbolKindEnv:
+		symbol := comp.scopes.GetEnv(symbolId)
+
+		t.AppendOp(code.OpGetEnv)
+		t.AppendInt(int64(symbol.Data().IndexInEnv()))
 	default:
 		panic(fmt.Errorf("unknown symbol kind: %v", symbolId.SymbolKind()))
 	}
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index 07586a6..9c21a3b 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -486,6 +486,43 @@ func TestFunctionArgs(t *testing.T) {
 	mustCompileTo(t, src, expected)
 }
 
+func TestClosureEnv(t *testing.T) {
+	src := `
+	fn create() {
+		var x = 0
+		fn closure() {
+			x = x + 1
+			return x
+		}
+
+		return closure
+	}
+	`
+
+	expected := `
+	push_function @create
+	halt
+
+	@closure:
+	get_env 0
+	push_int 1
+	add
+	set_env 0
+	get_env 0
+	ret
+
+	@create:
+	push_int 0
+	push_function @closure
+	add_to_env 0
+
+	get_local 1
+	ret
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
 func mustCompileTo(t *testing.T, src, expected string) {
 	scanner := scanner.New(strings.NewReader(src))
 	tokens, err := scanner.Scan()
diff --git a/pkg/lang/compiler/scope/scope_chain.go b/pkg/lang/compiler/scope/scope_chain.go
index f4b9f46..1b83c75 100644
--- a/pkg/lang/compiler/scope/scope_chain.go
+++ b/pkg/lang/compiler/scope/scope_chain.go
@@ -91,19 +91,14 @@ func (sc *ScopeChain) Exit() int {
 	id := sc.CurrentScopeID()
 	stackSpace := len(sc.symbolScopes[id].variableSymbols)
 
-	sc.symbolScopes[id] = SymbolScope{}
 	sc.symbolScopes = sc.symbolScopes[:id]
 
 	if sc.CurrentLoop() != nil && sc.CurrentLoop().id == id {
-		lastLoopScope := len(sc.loopScopes) - 1
-		sc.loopScopes[lastLoopScope] = LoopScope{}
-		sc.loopScopes = sc.loopScopes[:lastLoopScope]
+		sc.loopScopes = sc.loopScopes[:len(sc.loopScopes)-1]
 	}
 
 	if sc.CurrentFunction().id == id {
-		lastFunctionScope := len(sc.functionScopes) - 1
-		sc.functionScopes[lastFunctionScope] = FunctionScope{}
-		sc.functionScopes = sc.functionScopes[:lastFunctionScope]
+		sc.functionScopes = sc.functionScopes[:len(sc.functionScopes)-1]
 	}
 
 	return stackSpace
@@ -178,11 +173,23 @@ func (sc *ScopeChain) CreateFunctionSubUnit(subUnitName string) code.Marker {
 }
 
 func (sc *ScopeChain) Lookup(name string) (SymbolID, bool) {
-	if id, ok := sc.nameToSymbol[name]; ok {
-		return id, true
+	id, ok := sc.nameToSymbol[name]
+	if !ok {
+		return SymbolID{}, false
 	}
 
-	return SymbolID{}, false
+	// Check whether the symbol is outside the current function scope.
+	fnScope := sc.CurrentFunction()
+	if id.scopeID < fnScope.id {
+		// Return env symbol instead of a local symbol.
+		return SymbolID{
+			symbolKind:   SymbolKindEnv,
+			scopeID:      id.scopeID,
+			indexInScope: id.indexInScope,
+		}, true
+	}
+
+	return id, true
 }
 
 func (sc *ScopeChain) GetVariable(id SymbolID) Symbol[SymbolVariable] {
@@ -192,3 +199,36 @@ func (sc *ScopeChain) GetVariable(id SymbolID) Symbol[SymbolVariable] {
 
 	return sc.symbolScopes[id.scopeID].variableSymbols[id.indexInScope]
 }
+
+func (sc *ScopeChain) GetEnv(id SymbolID) Symbol[SymbolEnv] {
+	if id.symbolKind != SymbolKindEnv {
+		panic("incorrect symbol id kind given")
+	}
+
+	symbol := sc.symbolScopes[id.scopeID].variableSymbols[id.indexInScope]
+
+	// Add the local to the function scope, if it is not already there.
+	fnScope := sc.CurrentFunction()
+
+	var indexInEnv int
+	alreadyUsed := false
+	for i, env := range fnScope.outsideSymbolsInEnv {
+		if env == id {
+			alreadyUsed = true
+			indexInEnv = i
+			break
+		}
+	}
+
+	if !alreadyUsed {
+		indexInEnv = len(fnScope.outsideSymbolsInEnv)
+		fnScope.outsideSymbolsInEnv = append(fnScope.outsideSymbolsInEnv, id)
+	}
+
+	return Symbol[SymbolEnv]{
+		name: symbol.name,
+		data: SymbolEnv{
+			indexInEnv: indexInEnv,
+		},
+	}
+}
diff --git a/pkg/lang/compiler/scope/scopes.go b/pkg/lang/compiler/scope/scopes.go
index e34b45a..7a1b20c 100644
--- a/pkg/lang/compiler/scope/scopes.go
+++ b/pkg/lang/compiler/scope/scopes.go
@@ -26,6 +26,8 @@ type FunctionScope struct {
 	id           ScopeID
 	unit         code.Marker
 	subUnitCount int
+
+	outsideSymbolsInEnv []SymbolID
 }
 
 func NewFunctionScope(id ScopeID, unit code.Marker) FunctionScope {
@@ -33,6 +35,8 @@ func NewFunctionScope(id ScopeID, unit code.Marker) FunctionScope {
 		id:           id,
 		unit:         unit,
 		subUnitCount: 0,
+
+		outsideSymbolsInEnv: make([]SymbolID, 0),
 	}
 }
 
@@ -44,6 +48,10 @@ func (sf FunctionScope) Unit() code.Marker {
 	return sf.unit
 }
 
+func (sf FunctionScope) OutsideSymbolsInEnv() []SymbolID {
+	return sf.outsideSymbolsInEnv
+}
+
 func (sf FunctionScope) IsRootScope() bool {
 	return sf.ID() == ScopeID(0)
 }
diff --git a/pkg/lang/compiler/scope/symbol.go b/pkg/lang/compiler/scope/symbol.go
index 3b50108..b87d5aa 100644
--- a/pkg/lang/compiler/scope/symbol.go
+++ b/pkg/lang/compiler/scope/symbol.go
@@ -10,17 +10,30 @@ func (id SymbolID) SymbolKind() SymbolKind {
 	return id.symbolKind
 }
 
+func (id SymbolID) ScopeID() ScopeID {
+	return id.scopeID
+}
+
+func (id SymbolID) IndexInScope() int {
+	return id.indexInScope
+}
+
 type SymbolKind int
 
 const (
 	// A variable symbol is bound to a local on the stack.
 	SymbolKindVariable SymbolKind = iota
+	// An env symbol is bound to a local on the stack, outside of the function's scope.
+	// Emitted at lookup time, so the SymbolScope has no array for them.
+	SymbolKindEnv SymbolKind = iota
 )
 
 func (s SymbolKind) String() string {
 	switch s {
 	case SymbolKindVariable:
 		return "variable"
+	case SymbolKindEnv:
+		return "env"
 	default:
 		panic("unknown symbol kind")
 	}
@@ -36,7 +49,7 @@ func (s Symbol[D]) Data() D {
 }
 
 type SymbolData interface {
-	SymbolVariable
+	SymbolVariable | SymbolEnv
 }
 
 type SymbolVariable struct {
@@ -46,3 +59,11 @@ type SymbolVariable struct {
 func (sv SymbolVariable) LocalIndex() int {
 	return sv.localIndex
 }
+
+type SymbolEnv struct {
+	indexInEnv int
+}
+
+func (se SymbolEnv) IndexInEnv() int {
+	return se.indexInEnv
+}