From 7a60e91c8c45663127cfa3aa31a14c930f717aeb Mon Sep 17 00:00:00 2001 From: Mel Date: Mon, 11 Jul 2022 03:00:21 +0200 Subject: Function compilation and minor fixes --- pkg/lang/compiler/compiler.go | 114 +++++++++++++++++++++++++++++++++---- pkg/lang/compiler/compiler_test.go | 48 ++++++++++++++++ pkg/lang/compiler/scope_chain.go | 37 ++++++++---- pkg/lang/vm/code/builder.go | 2 +- 4 files changed, 178 insertions(+), 23 deletions(-) (limited to 'pkg') diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go index a31ade7..ae2a71c 100644 --- a/pkg/lang/compiler/compiler.go +++ b/pkg/lang/compiler/compiler.go @@ -9,18 +9,29 @@ import ( type Compiler struct { ast ast.Program + funcs []*code.Builder + scopes ScopeChain } func New(ast ast.Program) *Compiler { return &Compiler{ - ast: ast, - + ast: ast, + funcs: make([]*code.Builder, 0), scopes: NewScopeChain(), } } func (comp *Compiler) Compile() (code.Code, error) { + // Pre-declare all top-level functions + for _, stmt := range comp.ast.Stmts { + if stmt.Kind == ast.StmtKindFnDecl { + if err := comp.preDeclareFunction(stmt.Value.(ast.StmtFnDecl)); err != nil { + return code.Code{}, err + } + } + } + target := code.NewBuilder() for _, stmt := range comp.ast.Stmts { @@ -31,9 +42,21 @@ func (comp *Compiler) Compile() (code.Code, error) { target.AppendOp(code.OpHalt) + for _, function := range comp.funcs { + target.AppendBuilder(*function) + } + return target.Build() } +func (comp *Compiler) preDeclareFunction(fnDeclStmt ast.StmtFnDecl) error { + if _, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value); !ok { + return fmt.Errorf("function %s already declared", fnDeclStmt.Name.Value) + } + + return nil +} + func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error { var err error switch stmt.Kind { @@ -42,7 +65,8 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error { case ast.StmtKindUse: panic("use statements not implemented") case ast.StmtKindFnDecl: - panic("function declaration statements not implemented") + fnDeclStmt := stmt.Value.(ast.StmtFnDecl) + err = comp.compileFnDeclStmt(t, fnDeclStmt) case ast.StmtKindObjectDecl: panic("object declaration statements not implemented") case ast.StmtKindVarDecl: @@ -60,7 +84,8 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error { case ast.StmtKindTry: panic("try statements not implemented") case ast.StmtKindReturn: - panic("return statements not implemented") + returnStmt := stmt.Value.(ast.StmtReturn) + err = comp.compileReturnStmt(t, returnStmt) case ast.StmtKindContinue: panic("continue statements not implemented") case ast.StmtKindBreak: @@ -77,6 +102,47 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error { return err } +func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDecl) error { + marker, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value) + if !ok { + // If we are in the root scope, the function was simply predeclared :) + if comp.scopes.IsRootScope() { + symbolID, ok := comp.scopes.Lookup(fnDeclStmt.Name.Value) + if !ok { + panic("function said it was declared but apparently it was lying") + } + + symbol := comp.scopes.GetFunction(symbolID) + + marker = symbol.data.marker + } else { + return fmt.Errorf("function %s already declared", fnDeclStmt.Name.Value) + } + } + + functionTarget := code.NewBuilder() + + functionTarget.PutMarker(marker) + + comp.scopes.EnterFunction(marker) + + for _, arg := range fnDeclStmt.Args { + if _, ok := comp.scopes.Declare(arg.Value); !ok { + return fmt.Errorf("variable %s already declared", arg.Value) + } + } + + if err := comp.compileBlockNode(&functionTarget, fnDeclStmt.Body); err != nil { + return err + } + + comp.scopes.Exit() + + comp.funcs = append(comp.funcs, &functionTarget) + + return nil +} + func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error { if err := comp.compileExpr(t, decl.Value); err != nil { return err @@ -313,6 +379,26 @@ func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn) return nil } +func (comp *Compiler) compileReturnStmt(t *code.Builder, returnStmt ast.StmtReturn) error { + // Check that we are in fact in a function + functionScope := comp.scopes.CurrentFunction() + if functionScope.data.(ScopeFunction).IsRootScope() { + return fmt.Errorf("can't return when not inside a function" + functionScope.data.(ScopeFunction).unit.String()) + } + + if returnStmt.Value.IsEmpty() { + t.AppendOp(code.OpPushNull) + } else { + if err := comp.compileExpr(t, returnStmt.Value); err != nil { + return err + } + } + + t.AppendOp(code.OpRet) + + return nil +} + func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error { switch expr.Kind { case ast.ExprKindBinary: @@ -500,15 +586,21 @@ func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) erro return fmt.Errorf("undefined symbol %s", expr.Value.Value) } - if symbolId.symbolKind != SymbolKindVariable { - return fmt.Errorf("%v values are not implemeted yet", symbolId.symbolKind) - } + // TODO: Add other ways how the symbol should be fetched. (local, env, global, etc.) + switch symbolId.symbolKind { + case SymbolKindVariable: + symbol := comp.scopes.GetVariable(symbolId) - symbol := comp.scopes.GetVariable(symbolId) + t.AppendOp(code.OpGetLocal) + t.AppendInt(int64(symbol.data.localIndex)) + case SymbolKindFunction: + symbol := comp.scopes.GetFunction(symbolId) - // TODO: Add boundries to check how the symbol should be fetched. (local, env, global, etc.) - t.AppendOp(code.OpGetLocal) - t.AppendInt(int64(symbol.data.localIndex)) + t.AppendOp(code.OpPushFunction) + t.AppendMarkerReference(symbol.data.marker) + default: + panic(fmt.Errorf("unknown symbol kind: %v", symbolId.symbolKind)) + } return nil } diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go index f3a20a5..cd62088 100644 --- a/pkg/lang/compiler/compiler_test.go +++ b/pkg/lang/compiler/compiler_test.go @@ -350,6 +350,54 @@ func TestForIn(t *testing.T) { mustCompileTo(t, src, expected) } +func TestSimpleFunction(t *testing.T) { + src := ` + var result = the_meaning_of_life() + + fn the_meaning_of_life() { + return 42 + } + ` + + expected := ` + push_function @the_meaning_of_life + call 0 + halt + + @the_meaning_of_life: + push_int 42 + ret + ` + + mustCompileTo(t, src, expected) +} + +func TestFunctionArgs(t *testing.T) { + src := ` + fn add(a, b) { + return a + b + } + + add(4, 5) + ` + + expected := ` + push_function @add + push_int 4 + push_int 5 + call 2 + halt + + @add: + get_local 0 + get_local 1 + add + ret + ` + + mustCompileTo(t, src, expected) +} + func mustCompileTo(t *testing.T, src, expected string) { scanner := scanner.New(strings.NewReader(src)) tokens, err := scanner.Scan() diff --git a/pkg/lang/compiler/scope_chain.go b/pkg/lang/compiler/scope_chain.go index ad176da..3a3b819 100644 --- a/pkg/lang/compiler/scope_chain.go +++ b/pkg/lang/compiler/scope_chain.go @@ -14,7 +14,7 @@ type ScopeChain struct { func NewScopeChain() ScopeChain { scopes := make([]Scope, 1) - scopes[0] = NewFunctionScope("") // Top-most scope is a function scope, so it can have sub-units + scopes[0] = NewFunctionScope(0, "") // Top-most scope is a function scope, so it can have sub-units return ScopeChain{ nameToSymbol: make(map[string]SymbolID), @@ -26,13 +26,17 @@ func (sc *ScopeChain) CurrentScopeID() ScopeID { return ScopeID(len(sc.scopes) - 1) } +func (sc *ScopeChain) IsRootScope() bool { + return sc.CurrentScopeID() == ScopeID(0) +} + func (sc *ScopeChain) Current() *Scope { return &sc.scopes[sc.CurrentScopeID()] } func (sc *ScopeChain) CurrentFunction() *Scope { // TODO: Probably should make this lookup constant by making a seperate array of function scopes - for i := len(sc.scopes) - 1; i <= 0; i++ { + for i := len(sc.scopes) - 1; i >= 0; i++ { if sc.scopes[i].kind == ScopeKindFunction { return &sc.scopes[i] } @@ -45,8 +49,9 @@ func (sc *ScopeChain) Enter() { sc.scopes = append(sc.scopes, NewNormalScope()) } -func (sc *ScopeChain) EnterFunction(unitName string) { - sc.scopes = append(sc.scopes, NewFunctionScope(unitName)) +func (sc *ScopeChain) EnterFunction(unit code.Marker) { + id := sc.CurrentScopeID() + 1 + sc.scopes = append(sc.scopes, NewFunctionScope(id, unit)) } func (sc *ScopeChain) Exit() { @@ -148,7 +153,7 @@ func (sc *ScopeChain) CreateFunctionSubUnit(subUnitName string) code.Marker { fnScope := sc.CurrentFunction() data := fnScope.data.(ScopeFunction) - name := data.unitName + name := data.unit if name == "" { name = code.Marker(subUnitName) } else { @@ -174,12 +179,12 @@ func (sc *ScopeChain) GetVariable(id SymbolID) Symbol[SymbolVariable] { return sc.scopes[id.scopeID].variableSymbols[id.indexInScope] } -func (sc *ScopeChain) GetFunction(id SymbolID) Symbol[SymbolVariable] { - if id.symbolKind != SymbolKindVariable { +func (sc *ScopeChain) GetFunction(id SymbolID) Symbol[SymbolFunction] { + if id.symbolKind != SymbolKindFunction { panic("incorrect symbol id kind given") } - return sc.scopes[id.scopeID].variableSymbols[id.indexInScope] + return sc.scopes[id.scopeID].functionSymbols[id.indexInScope] } type SymbolID struct { @@ -213,13 +218,14 @@ func NewNormalScope() Scope { } } -func NewFunctionScope(unitName string) Scope { +func NewFunctionScope(id ScopeID, unit code.Marker) Scope { return Scope{ variableSymbols: make([]Symbol[SymbolVariable], 0), functionSymbols: make([]Symbol[SymbolFunction], 0), kind: ScopeKindFunction, data: ScopeFunction{ - unitName: code.Marker(unitName), + id: id, + unit: unit, subUnitCount: 0, }, } @@ -228,10 +234,19 @@ func NewFunctionScope(unitName string) Scope { type ScopeNormal struct{} type ScopeFunction struct { - unitName code.Marker + id ScopeID + unit code.Marker subUnitCount int } +func (sf ScopeFunction) ID() ScopeID { + return sf.id +} + +func (sf ScopeFunction) IsRootScope() bool { + return sf.ID() == ScopeID(0) +} + type ScopeLoop struct { breakMarker code.Marker continueMarker code.Marker diff --git a/pkg/lang/vm/code/builder.go b/pkg/lang/vm/code/builder.go index adb2eed..f413ba1 100644 --- a/pkg/lang/vm/code/builder.go +++ b/pkg/lang/vm/code/builder.go @@ -68,7 +68,7 @@ func (b *Builder) AppendBuilder(other Builder) error { } } - b.SetMarker(marker, pc) + b.SetMarker(marker, b.Len()+pc) } for pc, marker := range other.markerRefs { -- cgit 1.4.1