// Package compiler turns a parsed jinx program into bytecode for the VM.
package compiler

import (
	"fmt"

	"jinx/pkg/lang/ast"
	"jinx/pkg/lang/vm/code"
)

// Compiler holds the program being compiled together with the scope chain
// used to resolve identifiers to local slots.
type Compiler struct {
	ast    ast.Program
	scopes ScopeChain
}

// New returns a Compiler for the given program with a fresh scope chain.
func New(ast ast.Program) *Compiler {
	return &Compiler{
		ast:    ast,
		scopes: NewScopeChain(),
	}
}

// Compile compiles every top-level statement of the program in order and
// terminates the output with a halt instruction.
//
// On failure the returned code is the zero value and must not be used.
func (comp *Compiler) Compile() (code.Code, error) {
	target := code.NewBuilder()

	for _, stmt := range comp.ast.Stmts {
		if err := comp.compileStmt(&target, stmt); err != nil {
			return code.Code{}, err
		}
	}

	target.AppendOp(code.OpHalt)
	return target.Build(), nil
}

// compileStmt dispatches on the statement kind and appends the statement's
// code to t. Statement kinds that are not supported yet are reported as an
// error rather than a panic, so that unsupported source cannot crash the
// compiler.
func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
	switch stmt.Kind {
	case ast.StmtKindEmpty:
		// Nothing to emit.
		return nil
	case ast.StmtKindVarDecl:
		return comp.compileVarDeclStmt(t, stmt.Value.(ast.StmtVarDecl))
	case ast.StmtKindIf:
		return comp.compileIfStmt(t, stmt.Value.(ast.StmtIf))
	case ast.StmtKindExpr:
		return comp.compileExpr(t, stmt.Value.(ast.StmtExpr).Value)
	default:
		return fmt.Errorf("statement of kind %v not implemented", stmt.Kind)
	}
}

// compileVarDeclStmt declares the variable in the scope chain and compiles
// its initializer. It fails if the scope chain refuses the declaration.
//
// No store instruction is emitted here; only the initializer's code is
// appended.
// NOTE(review): presumably the value the initializer leaves on the stack is
// the local's storage slot — confirm against the VM.
// NOTE(review): the name is declared before the initializer is compiled, so
// an initializer that mentions the name (`var x = x`) presumably resolves to
// the new, not yet initialized, variable — confirm this is intended.
func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error {
	if !comp.scopes.Declare(decl.Name.Value) {
		return fmt.Errorf("variable %s already declared", decl.Name.Value)
	}

	return comp.compileExpr(t, decl.Value)
}

// compileIfStmt compiles an if / elif / else chain and appends it to t.
func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error {
	// An if statement is composed out of CondNodes.
	// A cond node can be either the top `if` branch, an `elif` branch
	// in between, or the `else` branch at the bottom.
	// `if` and `elif` are identical, but the Cond expr of an `else` node is an empty ast.Expr.
	//
	// Each compiled CondNode consists of 4 parts.
	// 1. Condition check => Compiled Cond expr, pushes a bool onto the stack
	//                       Example: `push_false` for `if false {}`
	// 2. Condition jump  => If the condition is false, jump to the **next** CondNode,
	//                       or to the end if this is the last one.
	//                       Example: `jf @elif` or `jf @end`
	// 3. Then block      => Anything the user wants to execute.
	//                       Example: `push_int 1` or something
	// 4. Then jump       => Since the condition was true, we have to jump to the end of the CondNode list,
	//                       preventing other CondNodes from running. This is missing from the last CondNode.
	//                       Example: `jmp @end`
	//
	// All jump targets are offsets from the start of the if statement. They are
	// tracked in totalLength, which after each loop iteration equals the combined
	// size of every CondNode compiled so far (including the Then jumps that are
	// only appended in the second loop).

	subUnits := make([]code.Builder, 0, len(ifStmt.Conds))
	totalLength := 0
	jmpLength := 9 // The length of either of the jump parts: op: 1 + uint: 8 = 9

	for i, cond := range ifStmt.Conds {
		// Then block. Compiled first so its length is known when the
		// condition jump target is computed.
		thenTarget := code.NewBuilder()
		if err := comp.compileBlockNode(&thenTarget, cond.Then); err != nil {
			return err
		}

		totalLength += thenTarget.Len()
		if i != len(ifStmt.Conds)-1 {
			// Reserve room for the Then jump appended in the second loop.
			totalLength += jmpLength
		}

		// Condition check. Skipped for an `else` node, whose Cond is empty.
		conditionTarget := code.NewBuilder()
		if !cond.Cond.IsEmpty() {
			if err := comp.compileExpr(&conditionTarget, cond.Cond); err != nil {
				return err
			}

			totalLength += conditionTarget.Len() + jmpLength // condjmp

			// Condition jump. totalLength now points just past this CondNode,
			// i.e. at the next CondNode or at the end of the statement.
			conditionTarget.AppendOp(code.OpJf)
			conditionTarget.AppendReferenceToPc(int64(totalLength))
		}

		subUnit := conditionTarget
		subUnit.AppendBuilder(thenTarget)
		subUnits = append(subUnits, subUnit)
	}

	result := code.NewBuilder()

	// Then jumps. Only now is the full length, and with it the end address, known.
	for i, subUnit := range subUnits {
		if i != len(ifStmt.Conds)-1 {
			subUnit.AppendOp(code.OpJmp)
			subUnit.AppendReferenceToPc(int64(totalLength))
		}

		// The jump references above are already relative to the start of the
		// whole if statement, so they must not be shifted again here.
		// NOTE(review): references emitted inside a Then block (e.g. by a nested
		// if) are presumably only shifted by the AppendBuilder call into the
		// sub-unit, not by the sub-units preceding it — confirm that nested ifs
		// inside elif/else branches resolve to the right addresses.
		result.AppendBuilderWithoutAdjustingReferences(subUnit)
	}

	t.AppendBuilder(result)
	return nil
}

// compileExpr dispatches on the expression kind and appends the expression's
// code to t. Unsupported and unknown expression kinds are reported as errors.
func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
	switch expr.Kind {
	case ast.ExprKindBinary:
		return comp.compileBinaryExpr(t, expr.Value.(ast.ExprBinary))
	case ast.ExprKindUnary:
		return comp.compileUnaryExpr(t, expr.Value.(ast.ExprUnary))
	case ast.ExprKindCall:
		return comp.compileCallExpr(t, expr.Value.(ast.ExprCall))
	case ast.ExprKindSubscription:
		return comp.compileSubscriptionExpr(t, expr.Value.(ast.ExprSubscription))
	case ast.ExprKindGroup:
		return comp.compileGroupExpr(t, expr.Value.(ast.ExprGroup))
	case ast.ExprKindFnLit:
		return fmt.Errorf("function literals not implemented")
	case ast.ExprKindArrayLit:
		return fmt.Errorf("array literals not implemented")
	case ast.ExprKindIdent:
		return comp.compileIdentExpr(t, expr.Value.(ast.ExprIdent))
	case ast.ExprKindIntLit:
		return comp.compileIntLitExpr(t, expr.Value.(ast.ExprIntLit))
	case ast.ExprKindFloatLit:
		return comp.compileFloatLitExpr(t, expr.Value.(ast.ExprFloatLit))
	case ast.ExprKindStringLit:
		return comp.compileStringLitExpr(t, expr.Value.(ast.ExprStringLit))
	case ast.ExprKindBoolLit:
		return comp.compileBoolLitExpr(t, expr.Value.(ast.ExprBoolLit))
	case ast.ExprKindNullLit:
		return comp.compileNullLitExpr(t, expr.Value.(ast.ExprNullLit))
	case ast.ExprKindThis:
		return fmt.Errorf("this expressions not implemented")
	default:
		return fmt.Errorf("unknown expression kind %v", expr.Kind)
	}
}

// compileBinaryExpr compiles a binary expression. Assignments are delegated to
// compileAssignExpr; for every other operator the left operand is compiled,
// then the right operand, then the operator's instruction is appended.
func (comp *Compiler) compileBinaryExpr(t *code.Builder, expr ast.ExprBinary) error {
	if expr.Op == ast.BinOpAssign {
		return comp.compileAssignExpr(t, expr)
	}

	if err := comp.compileExpr(t, expr.Left); err != nil {
		return err
	}
	if err := comp.compileExpr(t, expr.Right); err != nil {
		return err
	}

	switch expr.Op {
	case ast.BinOpPlus:
		t.AppendOp(code.OpAdd)
	case ast.BinOpMinus:
		t.AppendOp(code.OpSub)
	case ast.BinOpLte:
		t.AppendOp(code.OpLte)
	case ast.BinOpStar,
		ast.BinOpSlash,
		ast.BinOpPercent,
		ast.BinOpEq,
		ast.BinOpNeq,
		ast.BinOpLt,
		ast.BinOpGt,
		ast.BinOpGte:
		// TODO: emit code.OpMul, code.OpDiv, code.OpMod, code.OpEq, code.OpNeq,
		// code.OpLt, code.OpGt and code.OpGte respectively.
		return fmt.Errorf("binary operator %v not implemented", expr.Op)
	default:
		return fmt.Errorf("unknown binary operator %v", expr.Op)
	}

	return nil
}

// compileAssignExpr compiles an assignment. Only identifiers are accepted as
// the assignment target, and the identifier must resolve in the scope chain.
// The right-hand side is compiled first, followed by a set-local instruction
// carrying the symbol's local index.
func (comp *Compiler) compileAssignExpr(t *code.Builder, expr ast.ExprBinary) error {
	if expr.Left.Kind != ast.ExprKindIdent {
		return fmt.Errorf("lvalues other than identifiers not implemented")
	}

	name := expr.Left.Value.(ast.ExprIdent).Value.Value

	symbol, ok := comp.scopes.Lookup(name)
	if !ok {
		return fmt.Errorf("variable %s not declared", name)
	}

	if err := comp.compileExpr(t, expr.Right); err != nil {
		return err
	}

	t.AppendOp(code.OpSetLocal)
	t.AppendInt(int64(symbol.localIndex))
	return nil
}

// compileUnaryExpr compiles the operand of a unary expression. No unary
// operator has an instruction yet, so after a successful operand compile the
// function always reports an error.
func (comp *Compiler) compileUnaryExpr(t *code.Builder, expr ast.ExprUnary) error {
	if err := comp.compileExpr(t, expr.Value); err != nil {
		return err
	}

	switch expr.Op {
	case ast.UnOpBang, ast.UnOpMinus:
		return fmt.Errorf("unary operator %v not implemented", expr.Op)
	default:
		return fmt.Errorf("unknown unary operator %v", expr.Op)
	}
}

// compileCallExpr compiles the callee, then each argument in source order,
// and finally a call instruction carrying the argument count.
func (comp *Compiler) compileCallExpr(t *code.Builder, expr ast.ExprCall) error {
	if err := comp.compileExpr(t, expr.Callee); err != nil {
		return err
	}

	for _, arg := range expr.Args {
		if err := comp.compileExpr(t, arg); err != nil {
			return err
		}
	}

	t.AppendOp(code.OpCall)
	t.AppendInt(int64(len(expr.Args)))
	return nil
}

// compileSubscriptionExpr compiles the subscripted object, then the key, and
// appends an index instruction.
func (comp *Compiler) compileSubscriptionExpr(t *code.Builder, expr ast.ExprSubscription) error {
	if err := comp.compileExpr(t, expr.Obj); err != nil {
		return err
	}
	if err := comp.compileExpr(t, expr.Key); err != nil {
		return err
	}

	t.AppendOp(code.OpIndex)
	return nil
}

// compileGroupExpr compiles the inner expression of a group; the group itself
// emits no code.
func (comp *Compiler) compileGroupExpr(t *code.Builder, expr ast.ExprGroup) error {
	return comp.compileExpr(t, expr.Value)
}

// compileIdentExpr resolves the identifier in the scope chain and appends a
// get-local instruction carrying the symbol's local index. It fails if the
// identifier cannot be resolved.
func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) error {
	symbol, ok := comp.scopes.Lookup(expr.Value.Value)
	if !ok {
		return fmt.Errorf("undefined symbol %s", expr.Value.Value)
	}

	// TODO: Add boundaries to check how the symbol should be fetched. (local, env, global, etc.)
	t.AppendOp(code.OpGetLocal)
	t.AppendInt(int64(symbol.localIndex))
	return nil
}

// compileIntLitExpr appends a push-int instruction with the literal's value.
func (comp *Compiler) compileIntLitExpr(t *code.Builder, expr ast.ExprIntLit) error {
	t.AppendOp(code.OpPushInt)
	t.AppendInt(int64(expr.Value))
	return nil
}

// compileFloatLitExpr appends a push-float instruction with the literal's value.
func (comp *Compiler) compileFloatLitExpr(t *code.Builder, expr ast.ExprFloatLit) error {
	t.AppendOp(code.OpPushFloat)
	t.AppendFloat(expr.Value)
	return nil
}

// compileStringLitExpr appends a push-string instruction with the literal's value.
func (comp *Compiler) compileStringLitExpr(t *code.Builder, expr ast.ExprStringLit) error {
	t.AppendOp(code.OpPushString)
	t.AppendString(expr.Value)
	return nil
}

// compileBoolLitExpr appends push-true or push-false depending on the literal.
func (comp *Compiler) compileBoolLitExpr(t *code.Builder, expr ast.ExprBoolLit) error {
	if expr.Value {
		t.AppendOp(code.OpPushTrue)
	} else {
		t.AppendOp(code.OpPushFalse)
	}
	return nil
}

// compileNullLitExpr appends a push-null instruction.
func (comp *Compiler) compileNullLitExpr(t *code.Builder, expr ast.ExprNullLit) error {
	t.AppendOp(code.OpPushNull)
	return nil
}

// compileBlockNode compiles the statements of a block in order, stopping at
// the first error.
func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) error {
	for _, stmt := range block.Stmts {
		if err := comp.compileStmt(t, stmt); err != nil {
			return err
		}
	}
	return nil
}