about summary refs log tree commit diff
path: root/pkg
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-07-11 03:00:21 +0200
committerMel <einebeere@gmail.com>2022-07-11 03:00:21 +0200
commit7a60e91c8c45663127cfa3aa31a14c930f717aeb (patch)
tree245101bb6acdfdbfedb2ceea11759d5a7a0c43bd /pkg
parent0a6339f5e2008a29df1b03ca012e69bd1dfd46cc (diff)
downloadjinx-7a60e91c8c45663127cfa3aa31a14c930f717aeb.tar.zst
jinx-7a60e91c8c45663127cfa3aa31a14c930f717aeb.zip
Function compilation and minor fixes
Diffstat (limited to 'pkg')
-rw-r--r--pkg/lang/compiler/compiler.go114
-rw-r--r--pkg/lang/compiler/compiler_test.go48
-rw-r--r--pkg/lang/compiler/scope_chain.go37
-rw-r--r--pkg/lang/vm/code/builder.go2
4 files changed, 178 insertions, 23 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index a31ade7..ae2a71c 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -9,18 +9,29 @@ import (
 type Compiler struct {
 	ast ast.Program
 
+	funcs []*code.Builder
+
 	scopes ScopeChain
 }
 
 func New(ast ast.Program) *Compiler {
 	return &Compiler{
-		ast: ast,
-
+		ast:    ast,
+		funcs:  make([]*code.Builder, 0),
 		scopes: NewScopeChain(),
 	}
 }
 
 func (comp *Compiler) Compile() (code.Code, error) {
+	// Pre-declare all top-level functions
+	for _, stmt := range comp.ast.Stmts {
+		if stmt.Kind == ast.StmtKindFnDecl {
+			if err := comp.preDeclareFunction(stmt.Value.(ast.StmtFnDecl)); err != nil {
+				return code.Code{}, err
+			}
+		}
+	}
+
 	target := code.NewBuilder()
 
 	for _, stmt := range comp.ast.Stmts {
@@ -31,9 +42,21 @@ func (comp *Compiler) Compile() (code.Code, error) {
 
 	target.AppendOp(code.OpHalt)
 
+	for _, function := range comp.funcs {
+		target.AppendBuilder(*function)
+	}
+
 	return target.Build()
 }
 
+func (comp *Compiler) preDeclareFunction(fnDeclStmt ast.StmtFnDecl) error {
+	if _, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value); !ok {
+		return fmt.Errorf("function %s already declared", fnDeclStmt.Name.Value)
+	}
+
+	return nil
+}
+
 func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	var err error
 	switch stmt.Kind {
@@ -42,7 +65,8 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	case ast.StmtKindUse:
 		panic("use statements not implemented")
 	case ast.StmtKindFnDecl:
-		panic("function declaration statements not implemented")
+		fnDeclStmt := stmt.Value.(ast.StmtFnDecl)
+		err = comp.compileFnDeclStmt(t, fnDeclStmt)
 	case ast.StmtKindObjectDecl:
 		panic("object declaration statements not implemented")
 	case ast.StmtKindVarDecl:
@@ -60,7 +84,8 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	case ast.StmtKindTry:
 		panic("try statements not implemented")
 	case ast.StmtKindReturn:
-		panic("return statements not implemented")
+		returnStmt := stmt.Value.(ast.StmtReturn)
+		err = comp.compileReturnStmt(t, returnStmt)
 	case ast.StmtKindContinue:
 		panic("continue statements not implemented")
 	case ast.StmtKindBreak:
@@ -77,6 +102,47 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	return err
 }
 
+func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDecl) error {
+	marker, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value)
+	if !ok {
+		// If we are in the root scope, the function was simply predeclared :)
+		if comp.scopes.IsRootScope() {
+			symbolID, ok := comp.scopes.Lookup(fnDeclStmt.Name.Value)
+			if !ok {
+				panic("function said it was declared but apparently it was lying")
+			}
+
+			symbol := comp.scopes.GetFunction(symbolID)
+
+			marker = symbol.data.marker
+		} else {
+			return fmt.Errorf("function %s already declared", fnDeclStmt.Name.Value)
+		}
+	}
+
+	functionTarget := code.NewBuilder()
+
+	functionTarget.PutMarker(marker)
+
+	comp.scopes.EnterFunction(marker)
+
+	for _, arg := range fnDeclStmt.Args {
+		if _, ok := comp.scopes.Declare(arg.Value); !ok {
+			return fmt.Errorf("variable %s already declared", arg.Value)
+		}
+	}
+
+	if err := comp.compileBlockNode(&functionTarget, fnDeclStmt.Body); err != nil {
+		return err
+	}
+
+	comp.scopes.Exit()
+
+	comp.funcs = append(comp.funcs, &functionTarget)
+
+	return nil
+}
+
 func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error {
 	if err := comp.compileExpr(t, decl.Value); err != nil {
 		return err
@@ -313,6 +379,26 @@ func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn)
 	return nil
 }
 
+func (comp *Compiler) compileReturnStmt(t *code.Builder, returnStmt ast.StmtReturn) error {
+	// Check that we are in fact in a function
+	functionScope := comp.scopes.CurrentFunction()
+	if functionScope.data.(ScopeFunction).IsRootScope() {
+		return fmt.Errorf("can't return when not inside a function" + functionScope.data.(ScopeFunction).unit.String())
+	}
+
+	if returnStmt.Value.IsEmpty() {
+		t.AppendOp(code.OpPushNull)
+	} else {
+		if err := comp.compileExpr(t, returnStmt.Value); err != nil {
+			return err
+		}
+	}
+
+	t.AppendOp(code.OpRet)
+
+	return nil
+}
+
 func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 	switch expr.Kind {
 	case ast.ExprKindBinary:
@@ -500,15 +586,21 @@ func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) erro
 		return fmt.Errorf("undefined symbol %s", expr.Value.Value)
 	}
 
-	if symbolId.symbolKind != SymbolKindVariable {
-		return fmt.Errorf("%v values are not implemeted yet", symbolId.symbolKind)
-	}
+	// TODO: Add other ways how the symbol should be fetched. (local, env, global, etc.)
+	switch symbolId.symbolKind {
+	case SymbolKindVariable:
+		symbol := comp.scopes.GetVariable(symbolId)
 
-	symbol := comp.scopes.GetVariable(symbolId)
+		t.AppendOp(code.OpGetLocal)
+		t.AppendInt(int64(symbol.data.localIndex))
+	case SymbolKindFunction:
+		symbol := comp.scopes.GetFunction(symbolId)
 
-	// TODO: Add boundries to check how the symbol should be fetched. (local, env, global, etc.)
-	t.AppendOp(code.OpGetLocal)
-	t.AppendInt(int64(symbol.data.localIndex))
+		t.AppendOp(code.OpPushFunction)
+		t.AppendMarkerReference(symbol.data.marker)
+	default:
+		panic(fmt.Errorf("unknown symbol kind: %v", symbolId.symbolKind))
+	}
 
 	return nil
 }
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index f3a20a5..cd62088 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -350,6 +350,54 @@ func TestForIn(t *testing.T) {
 	mustCompileTo(t, src, expected)
 }
 
+func TestSimpleFunction(t *testing.T) {
+	src := `
+	var result = the_meaning_of_life()
+
+	fn the_meaning_of_life() {
+		return 42
+	}
+	`
+
+	expected := `
+	push_function @the_meaning_of_life
+	call 0
+	halt
+
+	@the_meaning_of_life:
+	push_int 42
+	ret
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
+func TestFunctionArgs(t *testing.T) {
+	src := `
+	fn add(a, b) {
+		return a + b
+	}
+	
+	add(4, 5)
+	`
+
+	expected := `
+	push_function @add
+	push_int 4
+	push_int 5
+	call 2
+	halt
+
+	@add:
+	get_local 0
+	get_local 1
+	add
+	ret
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
 func mustCompileTo(t *testing.T, src, expected string) {
 	scanner := scanner.New(strings.NewReader(src))
 	tokens, err := scanner.Scan()
diff --git a/pkg/lang/compiler/scope_chain.go b/pkg/lang/compiler/scope_chain.go
index ad176da..3a3b819 100644
--- a/pkg/lang/compiler/scope_chain.go
+++ b/pkg/lang/compiler/scope_chain.go
@@ -14,7 +14,7 @@ type ScopeChain struct {
 
 func NewScopeChain() ScopeChain {
 	scopes := make([]Scope, 1)
-	scopes[0] = NewFunctionScope("") // Top-most scope is a function scope, so it can have sub-units
+	scopes[0] = NewFunctionScope(0, "") // Top-most scope is a function scope, so it can have sub-units
 
 	return ScopeChain{
 		nameToSymbol: make(map[string]SymbolID),
@@ -26,13 +26,17 @@ func (sc *ScopeChain) CurrentScopeID() ScopeID {
 	return ScopeID(len(sc.scopes) - 1)
 }
 
+func (sc *ScopeChain) IsRootScope() bool {
+	return sc.CurrentScopeID() == ScopeID(0)
+}
+
 func (sc *ScopeChain) Current() *Scope {
 	return &sc.scopes[sc.CurrentScopeID()]
 }
 
 func (sc *ScopeChain) CurrentFunction() *Scope {
 	// TODO: Probably should make this lookup constant by making a seperate array of function scopes
-	for i := len(sc.scopes) - 1; i <= 0; i++ {
+	for i := len(sc.scopes) - 1; i >= 0; i++ {
 		if sc.scopes[i].kind == ScopeKindFunction {
 			return &sc.scopes[i]
 		}
@@ -45,8 +49,9 @@ func (sc *ScopeChain) Enter() {
 	sc.scopes = append(sc.scopes, NewNormalScope())
 }
 
-func (sc *ScopeChain) EnterFunction(unitName string) {
-	sc.scopes = append(sc.scopes, NewFunctionScope(unitName))
+func (sc *ScopeChain) EnterFunction(unit code.Marker) {
+	id := sc.CurrentScopeID() + 1
+	sc.scopes = append(sc.scopes, NewFunctionScope(id, unit))
 }
 
 func (sc *ScopeChain) Exit() {
@@ -148,7 +153,7 @@ func (sc *ScopeChain) CreateFunctionSubUnit(subUnitName string) code.Marker {
 	fnScope := sc.CurrentFunction()
 	data := fnScope.data.(ScopeFunction)
 
-	name := data.unitName
+	name := data.unit
 	if name == "" {
 		name = code.Marker(subUnitName)
 	} else {
@@ -174,12 +179,12 @@ func (sc *ScopeChain) GetVariable(id SymbolID) Symbol[SymbolVariable] {
 	return sc.scopes[id.scopeID].variableSymbols[id.indexInScope]
 }
 
-func (sc *ScopeChain) GetFunction(id SymbolID) Symbol[SymbolVariable] {
-	if id.symbolKind != SymbolKindVariable {
+func (sc *ScopeChain) GetFunction(id SymbolID) Symbol[SymbolFunction] {
+	if id.symbolKind != SymbolKindFunction {
 		panic("incorrect symbol id kind given")
 	}
 
-	return sc.scopes[id.scopeID].variableSymbols[id.indexInScope]
+	return sc.scopes[id.scopeID].functionSymbols[id.indexInScope]
 }
 
 type SymbolID struct {
@@ -213,13 +218,14 @@ func NewNormalScope() Scope {
 	}
 }
 
-func NewFunctionScope(unitName string) Scope {
+func NewFunctionScope(id ScopeID, unit code.Marker) Scope {
 	return Scope{
 		variableSymbols: make([]Symbol[SymbolVariable], 0),
 		functionSymbols: make([]Symbol[SymbolFunction], 0),
 		kind:            ScopeKindFunction,
 		data: ScopeFunction{
-			unitName:     code.Marker(unitName),
+			id:           id,
+			unit:         unit,
 			subUnitCount: 0,
 		},
 	}
@@ -228,10 +234,19 @@ func NewFunctionScope(unitName string) Scope {
 type ScopeNormal struct{}
 
 type ScopeFunction struct {
-	unitName     code.Marker
+	id           ScopeID
+	unit         code.Marker
 	subUnitCount int
 }
 
+func (sf ScopeFunction) ID() ScopeID {
+	return sf.id
+}
+
+func (sf ScopeFunction) IsRootScope() bool {
+	return sf.ID() == ScopeID(0)
+}
+
 type ScopeLoop struct {
 	breakMarker    code.Marker
 	continueMarker code.Marker
diff --git a/pkg/lang/vm/code/builder.go b/pkg/lang/vm/code/builder.go
index adb2eed..f413ba1 100644
--- a/pkg/lang/vm/code/builder.go
+++ b/pkg/lang/vm/code/builder.go
@@ -68,7 +68,7 @@ func (b *Builder) AppendBuilder(other Builder) error {
 			}
 		}
 
-		b.SetMarker(marker, pc)
+		b.SetMarker(marker, b.Len()+pc)
 	}
 
 	for pc, marker := range other.markerRefs {