package compiler import ( "fmt" "jinx/pkg/lang/ast" "jinx/pkg/lang/compiler/scope" "jinx/pkg/lang/vm/code" ) type Compiler struct { ast ast.Program funcs []*code.Builder scopes scope.ScopeChain } func New(ast ast.Program) *Compiler { return &Compiler{ ast: ast, funcs: make([]*code.Builder, 0), scopes: scope.NewScopeChain(), } } func (comp *Compiler) Compile() (code.Code, error) { // Pre-declare all top-level functions for _, stmt := range comp.ast.Stmts { if stmt.Kind == ast.StmtKindFnDecl { if err := comp.preDeclareFunction(stmt.Value.(ast.StmtFnDecl)); err != nil { return code.Code{}, err } } } target := code.NewBuilder() for _, stmt := range comp.ast.Stmts { if err := comp.compileStmt(&target, stmt); err != nil { return code.Code{}, err } } target.AppendOp(code.OpHalt) for _, function := range comp.funcs { target.AppendBuilder(*function) } return target.Build() } func (comp *Compiler) preDeclareFunction(fnDeclStmt ast.StmtFnDecl) error { if _, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value, uint(len(fnDeclStmt.Args))); !ok { return fmt.Errorf("function %s already declared", fnDeclStmt.Name.Value) } return nil } func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error { var err error switch stmt.Kind { case ast.StmtKindEmpty: // Do nothing. case ast.StmtKindUse: panic("use statements not implemented") case ast.StmtKindFnDecl: fnDeclStmt := stmt.Value.(ast.StmtFnDecl) err = comp.compileFnDeclStmt(t, fnDeclStmt) case ast.StmtKindObjectDecl: panic("object declaration statements not implemented") case ast.StmtKindVarDecl: decl := stmt.Value.(ast.StmtVarDecl) err = comp.compileVarDeclStmt(t, decl) case ast.StmtKindIf: ifStmt := stmt.Value.(ast.StmtIf) err = comp.compileIfStmt(t, ifStmt) case ast.StmtKindForCond: forCondStmt := stmt.Value.(ast.StmtForCond) err = comp.compileForCondStmt(t, forCondStmt) case ast.StmtKindForIn: forCondIn := stmt.Value.(ast.StmtForIn) err = comp.compileForInStmt(t, forCondIn) case ast.StmtKindTry: panic("try statements not implemented") case ast.StmtKindReturn: returnStmt := stmt.Value.(ast.StmtReturn) err = comp.compileReturnStmt(t, returnStmt) case ast.StmtKindContinue: panic("continue statements not implemented") case ast.StmtKindBreak: panic("break statements not implemented") case ast.StmtKindThrow: panic("throw statements not implemented") case ast.StmtKindExpr: expr := stmt.Value.(ast.StmtExpr).Value err = comp.compileExpr(t, expr) default: panic(fmt.Errorf("unknown statement kind: %d", stmt.Kind)) } return err } func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDecl) error { marker, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value, uint(len(fnDeclStmt.Args))) if !ok { // If we are in the root scope, the function was simply predeclared :) if comp.scopes.IsRootScope() { symbolID, ok := comp.scopes.Lookup(fnDeclStmt.Name.Value) if !ok { panic("function said it was declared but apparently it was lying") } symbol := comp.scopes.GetFunction(symbolID) marker = symbol.Data().Marker() } else { return fmt.Errorf("function %s already declared", fnDeclStmt.Name.Value) } } functionTarget := code.NewBuilder() functionTarget.PutMarker(marker) comp.scopes.EnterFunction(marker) for _, arg := range fnDeclStmt.Args { if _, ok := comp.scopes.Declare(arg.Value); !ok { return fmt.Errorf("variable %s already declared", arg.Value) } } if err := comp.compileBlockNode(&functionTarget, fnDeclStmt.Body); err != nil { return err } comp.scopes.Exit() comp.funcs = append(comp.funcs, &functionTarget) return nil } func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error { if err := comp.compileExpr(t, decl.Value); err != nil { return err } if _, ok := comp.scopes.Declare(decl.Name.Value); !ok { return fmt.Errorf("variable %s already declared", decl.Name.Value) } return nil } func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error { // An if statement is composed out of CondNodes. // A cond node can be either the top `if` branch, and `elif` branch // inbetween, or `else` branch at the bottom. // `if` and `elif` are identical, but the Cond expr of an `else` node is an empty ast.Expr. // // Each compiled CondNode consists of 4 parts. // 1. Condition check => Compiled Cond expr, pushes a bool onto the stack // Example: `push_false` for `if false {}` // 2. Condition jump => If the condition is false, jump to the **next** CondNode, if not last. // Example: `jf @elif` or `jf @end` // 3. Then block => Anything the user wants to execute. // Example: `push_int 1` or something // 4. Then jump => Since the condition was true, we have to jump to the end of the CondNode list, // preventing other CondNodes from running. This is missing from the last CondNode. // Example: `jmp @end` // First we create all the markers we'll need for the if statement parentMarker := comp.scopes.CreateAnonymousFunctionSubUnit() endMarker := parentMarker.SubMarker("end") condMarkers := make([]code.Marker, 0, len(ifStmt.Conds)-1) // We don't need a marker for the first CondNode. for i := 0; i < len(ifStmt.Conds)-1; i++ { condMarker := parentMarker.SubMarker("cond_%d", i+1) condMarkers = append(condMarkers, condMarker) } for i, cond := range ifStmt.Conds { isFirst := i == 0 isLast := i == len(ifStmt.Conds)-1 // If we aren't in the first CondNode, the node before it needs a marker to here. if !isFirst { marker := condMarkers[i-1] t.PutMarker(marker) } if !cond.Cond.IsEmpty() { // Condition check if err := comp.compileExpr(t, cond.Cond); err != nil { return err } // Condition jump t.AppendOp(code.OpJf) if isLast { t.AppendMarkerReference(endMarker) } else { nextCondMarker := condMarkers[i] t.AppendMarkerReference(nextCondMarker) } } // Then block if err := comp.compileBlockNode(t, cond.Then); err != nil { return err } // Then jump if !isLast { t.AppendOp(code.OpJmp) t.AppendMarkerReference(endMarker) } } t.PutMarker(endMarker) return nil } func (comp *Compiler) compileForCondStmt(t *code.Builder, forCondStmt ast.StmtForCond) error { // Parts: // 1. Condition check: Decides whether the loop should run // 2. Condition jump: Jumps to the end of the for if condition was false // 3. Do block: Does something // 4. Repeat jump: Jumps back to start parentMarker := comp.scopes.CreateAnonymousFunctionSubUnit() startMarker := parentMarker.SubMarker("start") endMarker := parentMarker.SubMarker("end") t.PutMarker(startMarker) if !forCondStmt.Cond.IsEmpty() { // Condition check if err := comp.compileExpr(t, forCondStmt.Cond); err != nil { return err } // Condition check t.AppendOp(code.OpJf) t.AppendMarkerReference(endMarker) } // Do block if err := comp.compileBlockNode(t, forCondStmt.Do); err != nil { return err } // Repeat jump t.AppendOp(code.OpJmp) t.AppendMarkerReference(startMarker) t.PutMarker(endMarker) return nil } func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn) error { // Mostly same as ForCond, but the condition is implicit. // Example for: `for x in [] {}` // 0. Preparation // push_array # collection stored in local 0 // push_int 0 # i stored in local 1 // push_null # x stored in local 2 // 1. Condition check (i < x.length()) // @check: // get_local 1 // get_local 0 // get_member "$length" // call 0 // lt // 2. Condition jump // jf @end // 3.1 Do preparation (aka setting the x variable) // get_local 0 // get_local 1 // index // set_local 2 // get_local 1 // push_int 1 // add // set_local 1 // 3. Do block // ... // 4. Repeat jump: // jmp @check // @end: // halt parentMarker := comp.scopes.CreateAnonymousFunctionSubUnit() checkMarker := parentMarker.SubUnit("check") endMarker := parentMarker.SubUnit("end") // Preparation if err := comp.compileExpr(t, forInStmt.Collection); err != nil { return err } collectionLocal := comp.scopes.DeclareAnonymous() t.AppendOp(code.OpPushInt) t.AppendInt(0) iLocal := comp.scopes.DeclareAnonymous() t.AppendOp(code.OpPushNull) nameLocal, ok := comp.scopes.Declare(forInStmt.Name.Value) if !ok { return fmt.Errorf("variable %s already declared", forInStmt.Name.Value) } // Condition check t.PutMarker(checkMarker) t.AppendOp(code.OpGetLocal) t.AppendInt(int64(iLocal)) t.AppendOp(code.OpGetLocal) t.AppendInt(int64(collectionLocal)) t.AppendOp(code.OpGetMember) t.AppendString("length") t.AppendOp(code.OpCall) t.AppendInt(0) t.AppendOp(code.OpLt) // Condition jump t.AppendOp(code.OpJf) t.AppendMarkerReference(endMarker) // Do Preparation t.AppendOp(code.OpGetLocal) t.AppendInt(int64(collectionLocal)) t.AppendOp(code.OpGetLocal) t.AppendInt(int64(iLocal)) t.AppendOp(code.OpIndex) t.AppendOp(code.OpSetLocal) t.AppendInt(int64(nameLocal)) t.AppendOp(code.OpGetLocal) t.AppendInt(int64(iLocal)) t.AppendOp(code.OpPushInt) t.AppendInt(1) t.AppendOp(code.OpAdd) t.AppendOp(code.OpSetLocal) t.AppendInt(int64(iLocal)) // Do block if err := comp.compileBlockNode(t, forInStmt.Do); err != nil { return err } // Repeat jump t.AppendOp(code.OpJmp) t.AppendMarkerReference(checkMarker) t.PutMarker(endMarker) return nil } func (comp *Compiler) compileReturnStmt(t *code.Builder, returnStmt ast.StmtReturn) error { // Check that we are in fact in a function functionScope := comp.scopes.CurrentFunction() if functionScope.IsRootScope() { return fmt.Errorf("can't return when not inside a function") } if returnStmt.Value.IsEmpty() { t.AppendOp(code.OpPushNull) } else { if err := comp.compileExpr(t, returnStmt.Value); err != nil { return err } } t.AppendOp(code.OpRet) return nil } func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error { switch expr.Kind { case ast.ExprKindBinary: return comp.compileBinaryExpr(t, expr.Value.(ast.ExprBinary)) case ast.ExprKindUnary: return comp.compileUnaryExpr(t, expr.Value.(ast.ExprUnary)) case ast.ExprKindCall: return comp.compileCallExpr(t, expr.Value.(ast.ExprCall)) case ast.ExprKindSubscription: return comp.compileSubscriptionExpr(t, expr.Value.(ast.ExprSubscription)) case ast.ExprKindGroup: return comp.compileGroupExpr(t, expr.Value.(ast.ExprGroup)) case ast.ExprKindFnLit: panic("not implemented") case ast.ExprKindArrayLit: return comp.compileArrayLitExpr(t, expr.Value.(ast.ExprArrayLit)) case ast.ExprKindIdent: return comp.compileIdentExpr(t, expr.Value.(ast.ExprIdent)) case ast.ExprKindIntLit: return comp.compileIntLitExpr(t, expr.Value.(ast.ExprIntLit)) case ast.ExprKindFloatLit: return comp.compileFloatLitExpr(t, expr.Value.(ast.ExprFloatLit)) case ast.ExprKindStringLit: return comp.compileStringLitExpr(t, expr.Value.(ast.ExprStringLit)) case ast.ExprKindBoolLit: return comp.compileBoolLitExpr(t, expr.Value.(ast.ExprBoolLit)) case ast.ExprKindNullLit: return comp.compileNullLitExpr(t, expr.Value.(ast.ExprNullLit)) case ast.ExprKindThis: panic("not implemented") default: panic("unknown expression kind") } } func (comp *Compiler) compileBinaryExpr(t *code.Builder, expr ast.ExprBinary) error { if expr.Op == ast.BinOpAssign { return comp.compileAssignExpr(t, expr) } if err := comp.compileExpr(t, expr.Left); err != nil { return err } if err := comp.compileExpr(t, expr.Right); err != nil { return err } switch expr.Op { case ast.BinOpPlus: t.AppendOp(code.OpAdd) case ast.BinOpMinus: t.AppendOp(code.OpSub) case ast.BinOpStar: t.AppendOp(code.OpMul) case ast.BinOpSlash: t.AppendOp(code.OpDiv) case ast.BinOpPercent: t.AppendOp(code.OpMod) case ast.BinOpEq: t.AppendOp(code.OpEq) case ast.BinOpNeq: //t.AppendOp(code.OpNeq) panic("not implemented") case ast.BinOpLt: t.AppendOp(code.OpLt) case ast.BinOpLte: t.AppendOp(code.OpLte) case ast.BinOpGt: t.AppendOp(code.OpGt) case ast.BinOpGte: t.AppendOp(code.OpGte) default: panic("unknown binary operator") } return nil } func (comp *Compiler) compileAssignExpr(t *code.Builder, expr ast.ExprBinary) error { if expr.Left.Kind != ast.ExprKindIdent { return fmt.Errorf("lvalues other than identifiers not implemented") } name := expr.Left.Value.(ast.ExprIdent).Value.Value symbolId, ok := comp.scopes.Lookup(name) if !ok { return fmt.Errorf("variable %s not declared", name) } if symbolId.SymbolKind() != scope.SymbolKindVariable { return fmt.Errorf("can't assign to a %v", symbolId.SymbolKind()) } symbol := comp.scopes.GetVariable(symbolId) if err := comp.compileExpr(t, expr.Right); err != nil { return err } t.AppendOp(code.OpSetLocal) t.AppendInt(int64(symbol.Data().LocalIndex())) return nil } func (comp *Compiler) compileUnaryExpr(t *code.Builder, expr ast.ExprUnary) error { if err := comp.compileExpr(t, expr.Value); err != nil { return err } switch expr.Op { case ast.UnOpBang: panic("not implemented") case ast.UnOpMinus: panic("not implemented") default: panic("unknown unary operator") } return nil } func (comp *Compiler) compileCallExpr(t *code.Builder, expr ast.ExprCall) error { if err := comp.compileExpr(t, expr.Callee); err != nil { return err } for i := 0; i < len(expr.Args); i++ { if err := comp.compileExpr(t, expr.Args[i]); err != nil { return err } } t.AppendOp(code.OpCall) t.AppendInt(int64(len(expr.Args))) return nil } func (comp *Compiler) compileSubscriptionExpr(t *code.Builder, expr ast.ExprSubscription) error { if err := comp.compileExpr(t, expr.Obj); err != nil { return err } if err := comp.compileExpr(t, expr.Key); err != nil { return err } t.AppendOp(code.OpIndex) return nil } func (comp *Compiler) compileGroupExpr(t *code.Builder, expr ast.ExprGroup) error { return comp.compileExpr(t, expr.Value) } func (comp *Compiler) compileArrayLitExpr(t *code.Builder, expr ast.ExprArrayLit) error { t.AppendOp(code.OpPushArray) arrayLocal := comp.scopes.DeclareTemporary() for _, value := range expr.Values { t.AppendOp(code.OpGetLocal) t.AppendInt(int64(arrayLocal)) t.AppendOp(code.OpGetMember) t.AppendString("push") if err := comp.compileExpr(t, value); err != nil { return err } t.AppendOp(code.OpCall) t.AppendInt(1) } return nil } func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) error { symbolId, ok := comp.scopes.Lookup(expr.Value.Value) if !ok { return fmt.Errorf("undefined symbol %s", expr.Value.Value) } // TODO: Add other ways how the symbol should be fetched. (local, env, global, etc.) switch symbolId.SymbolKind() { case scope.SymbolKindVariable: symbol := comp.scopes.GetVariable(symbolId) t.AppendOp(code.OpGetLocal) t.AppendInt(int64(symbol.Data().LocalIndex())) case scope.SymbolKindFunction: symbol := comp.scopes.GetFunction(symbolId) t.AppendOp(code.OpPushFunction) t.AppendMarkerReference(symbol.Data().Marker()) if symbol.Data().Args() != 0 { t.AppendOp(code.OpSetArgCount) t.AppendInt(int64(symbol.Data().Args())) } default: panic(fmt.Errorf("unknown symbol kind: %v", symbolId.SymbolKind())) } return nil } func (comp *Compiler) compileIntLitExpr(t *code.Builder, expr ast.ExprIntLit) error { t.AppendOp(code.OpPushInt) t.AppendInt(int64(expr.Value)) return nil } func (comp *Compiler) compileFloatLitExpr(t *code.Builder, expr ast.ExprFloatLit) error { t.AppendOp(code.OpPushFloat) t.AppendFloat(expr.Value) return nil } func (comp *Compiler) compileStringLitExpr(t *code.Builder, expr ast.ExprStringLit) error { t.AppendOp(code.OpPushString) t.AppendString(expr.Value) return nil } func (comp *Compiler) compileBoolLitExpr(t *code.Builder, expr ast.ExprBoolLit) error { if expr.Value { t.AppendOp(code.OpPushTrue) } else { t.AppendOp(code.OpPushFalse) } return nil } func (comp *Compiler) compileNullLitExpr(t *code.Builder, expr ast.ExprNullLit) error { t.AppendOp(code.OpPushNull) return nil } func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) error { for _, stmt := range block.Stmts { if err := comp.compileStmt(t, stmt); err != nil { return err } } return nil }