about summary refs log tree commit diff
path: root/pkg/lang/compiler/compiler.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/lang/compiler/compiler.go')
-rw-r--r--pkg/lang/compiler/compiler.go114
1 files changed, 103 insertions, 11 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index a31ade7..ae2a71c 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -9,18 +9,29 @@ import (
 type Compiler struct {
 	ast ast.Program
 
+	funcs []*code.Builder
+
 	scopes ScopeChain
 }
 
 func New(ast ast.Program) *Compiler {
 	return &Compiler{
-		ast: ast,
-
+		ast:    ast,
+		funcs:  make([]*code.Builder, 0),
 		scopes: NewScopeChain(),
 	}
 }
 
 func (comp *Compiler) Compile() (code.Code, error) {
+	// Pre-declare all top-level functions
+	for _, stmt := range comp.ast.Stmts {
+		if stmt.Kind == ast.StmtKindFnDecl {
+			if err := comp.preDeclareFunction(stmt.Value.(ast.StmtFnDecl)); err != nil {
+				return code.Code{}, err
+			}
+		}
+	}
+
 	target := code.NewBuilder()
 
 	for _, stmt := range comp.ast.Stmts {
@@ -31,9 +42,21 @@ func (comp *Compiler) Compile() (code.Code, error) {
 
 	target.AppendOp(code.OpHalt)
 
+	for _, function := range comp.funcs {
+		target.AppendBuilder(*function)
+	}
+
 	return target.Build()
 }
 
+func (comp *Compiler) preDeclareFunction(fnDeclStmt ast.StmtFnDecl) error {
+	if _, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value); !ok {
+		return fmt.Errorf("function %s already declared", fnDeclStmt.Name.Value)
+	}
+
+	return nil
+}
+
 func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	var err error
 	switch stmt.Kind {
@@ -42,7 +65,8 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	case ast.StmtKindUse:
 		panic("use statements not implemented")
 	case ast.StmtKindFnDecl:
-		panic("function declaration statements not implemented")
+		fnDeclStmt := stmt.Value.(ast.StmtFnDecl)
+		err = comp.compileFnDeclStmt(t, fnDeclStmt)
 	case ast.StmtKindObjectDecl:
 		panic("object declaration statements not implemented")
 	case ast.StmtKindVarDecl:
@@ -60,7 +84,8 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	case ast.StmtKindTry:
 		panic("try statements not implemented")
 	case ast.StmtKindReturn:
-		panic("return statements not implemented")
+		returnStmt := stmt.Value.(ast.StmtReturn)
+		err = comp.compileReturnStmt(t, returnStmt)
 	case ast.StmtKindContinue:
 		panic("continue statements not implemented")
 	case ast.StmtKindBreak:
@@ -77,6 +102,47 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	return err
 }
 
+func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDecl) error {
+	marker, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value)
+	if !ok {
+		// If we are in the root scope, the function was simply predeclared :)
+		if comp.scopes.IsRootScope() {
+			symbolID, ok := comp.scopes.Lookup(fnDeclStmt.Name.Value)
+			if !ok {
+				panic("function said it was declared but apparently it was lying")
+			}
+
+			symbol := comp.scopes.GetFunction(symbolID)
+
+			marker = symbol.data.marker
+		} else {
+			return fmt.Errorf("function %s already declared", fnDeclStmt.Name.Value)
+		}
+	}
+
+	functionTarget := code.NewBuilder()
+
+	functionTarget.PutMarker(marker)
+
+	comp.scopes.EnterFunction(marker)
+
+	for _, arg := range fnDeclStmt.Args {
+		if _, ok := comp.scopes.Declare(arg.Value); !ok {
+			return fmt.Errorf("variable %s already declared", arg.Value)
+		}
+	}
+
+	if err := comp.compileBlockNode(&functionTarget, fnDeclStmt.Body); err != nil {
+		return err
+	}
+
+	comp.scopes.Exit()
+
+	comp.funcs = append(comp.funcs, &functionTarget)
+
+	return nil
+}
+
 func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error {
 	if err := comp.compileExpr(t, decl.Value); err != nil {
 		return err
@@ -313,6 +379,26 @@ func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn)
 	return nil
 }
 
+func (comp *Compiler) compileReturnStmt(t *code.Builder, returnStmt ast.StmtReturn) error {
+	// Check that we are in fact in a function
+	functionScope := comp.scopes.CurrentFunction()
+	if functionScope.data.(ScopeFunction).IsRootScope() {
+		return fmt.Errorf("can't return when not inside a function" + functionScope.data.(ScopeFunction).unit.String())
+	}
+
+	if returnStmt.Value.IsEmpty() {
+		t.AppendOp(code.OpPushNull)
+	} else {
+		if err := comp.compileExpr(t, returnStmt.Value); err != nil {
+			return err
+		}
+	}
+
+	t.AppendOp(code.OpRet)
+
+	return nil
+}
+
 func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 	switch expr.Kind {
 	case ast.ExprKindBinary:
@@ -500,15 +586,21 @@ func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) erro
 		return fmt.Errorf("undefined symbol %s", expr.Value.Value)
 	}
 
-	if symbolId.symbolKind != SymbolKindVariable {
-		return fmt.Errorf("%v values are not implemeted yet", symbolId.symbolKind)
-	}
+	// TODO: Add other ways how the symbol should be fetched. (local, env, global, etc.)
+	switch symbolId.symbolKind {
+	case SymbolKindVariable:
+		symbol := comp.scopes.GetVariable(symbolId)
 
-	symbol := comp.scopes.GetVariable(symbolId)
+		t.AppendOp(code.OpGetLocal)
+		t.AppendInt(int64(symbol.data.localIndex))
+	case SymbolKindFunction:
+		symbol := comp.scopes.GetFunction(symbolId)
 
-	// TODO: Add boundries to check how the symbol should be fetched. (local, env, global, etc.)
-	t.AppendOp(code.OpGetLocal)
-	t.AppendInt(int64(symbol.data.localIndex))
+		t.AppendOp(code.OpPushFunction)
+		t.AppendMarkerReference(symbol.data.marker)
+	default:
+		panic(fmt.Errorf("unknown symbol kind: %v", symbolId.symbolKind))
+	}
 
 	return nil
 }