package compiler

import (
	"fmt"

	"jinx/pkg/lang/ast"
	"jinx/pkg/lang/vm/code"
)

// jmpInstrLen is the encoded size in bytes of a jump instruction (OpJmp or
// OpJf) together with its operand. It assumes a one-byte opcode followed by
// an 8-byte integer written by code.Builder.AppendInt.
// NOTE(review): must stay in sync with the instruction encoding in package code.
const jmpInstrLen = 9

// Compiler translates an ast.Program into bytecode built with package code.
type Compiler struct {
	ast    ast.Program
	scopes ScopeChain

	// baseOffset is the absolute offset, within the final code, at which the
	// builder currently being written begins. It is non-zero only while
	// compiling into a temporary builder (the branches of an if statement),
	// and lets nested code compute absolute jump targets.
	baseOffset int
}

// New returns a Compiler for program with a fresh scope chain.
func New(program ast.Program) *Compiler {
	return &Compiler{
		ast:    program,
		scopes: NewScopeChain(),
	}
}

// Compile compiles every top-level statement of the program in order and
// terminates the result with OpHalt. On error the returned code is empty.
func (comp *Compiler) Compile() (code.Code, error) {
	target := code.NewBuilder()

	for _, stmt := range comp.ast.Stmts {
		if err := comp.compileStmt(&target, stmt); err != nil {
			return code.Code{}, err
		}
	}

	target.AppendOp(code.OpHalt)

	return target.Build(), nil
}

// compileStmt appends the code for a single statement to t. Statement kinds
// without code generation yet are reported as errors rather than panics.
func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
	switch stmt.Kind {
	case ast.StmtKindEmpty:
		// Nothing to emit.
		return nil
	case ast.StmtKindVarDecl:
		return comp.compileVarDeclStmt(t, stmt.Value.(ast.StmtVarDecl))
	case ast.StmtKindIf:
		return comp.compileIfStmt(t, stmt.Value.(ast.StmtIf))
	case ast.StmtKindExpr:
		// NOTE(review): the expression's result is left on the stack. Variable
		// declarations also leave their value on the stack (no OpSetLocal), which
		// suggests locals live in stack slots; a leftover value here may then
		// shift the slots of later locals. Consider emitting a pop once the VM
		// provides one — TODO confirm against the VM.
		return comp.compileExpr(t, stmt.Value.(ast.StmtExpr).Value)
	default:
		return fmt.Errorf("statement of kind %v not implemented", stmt.Kind)
	}
}

// compileVarDeclStmt compiles the initializer of a variable declaration and
// then declares the name in the scope chain. The initializer's value is left
// on the stack; no OpSetLocal is emitted.
//
// The initializer is compiled before the name is declared so that it cannot
// refer to the variable it is initializing (e.g. `var x = x`).
func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error {
	if err := comp.compileExpr(t, decl.Value); err != nil {
		return err
	}

	if !comp.scopes.Declare(decl.Name.Value) {
		return fmt.Errorf("variable %s already declared", decl.Name.Value)
	}

	return nil
}

// compileIfStmt compiles an if/elif/else chain. Each branch becomes one unit:
//
//	<cond>      // only on if/elif branches
//	jf @next    // only on if/elif branches; skips to the next unit
//	<then>
//	jmp @end    // on every unit except the last
//
// Jump operands are absolute offsets in the final code, so each unit is
// assembled in a temporary builder whose absolute start (baseOffset) is known
// before its body is compiled. That keeps jumps emitted by nested if
// statements correct once the temporary code is spliced into t.
func (comp *Compiler) compileIfStmt(t *code.Builder, ifstmt ast.StmtIf) error {
	units := make([]code.Builder, 0, len(ifstmt.Conds))
	last := len(ifstmt.Conds) - 1
	pos := comp.baseOffset + t.Len() // absolute offset where the next unit starts

	for i, cond := range ifstmt.Conds {
		hasCond := !cond.Cond.IsEmpty()

		// Compile the condition before the body (source order), so names
		// declared in the body are not visible to the condition.
		unit := code.NewBuilder()
		if hasCond {
			err := comp.withBaseOffset(pos, func() error {
				return comp.compileExpr(&unit, cond.Cond)
			})
			if err != nil {
				return err
			}
		}

		thenStart := pos + unit.Len()
		if hasCond {
			thenStart += jmpInstrLen // condjmp
		}

		thenCode := code.NewBuilder()
		err := comp.withBaseOffset(thenStart, func() error {
			return comp.compileBlockNode(&thenCode, cond.Then)
		})
		if err != nil {
			return err
		}

		next := thenStart + thenCode.Len()
		if i != last {
			next += jmpInstrLen // thenjmp
		}

		if hasCond {
			unit.AppendOp(code.OpJf)
			unit.AppendInt(int64(next))
		}
		unit.AppendRaw(thenCode.Code())

		units = append(units, unit)
		pos = next
	}

	end := pos
	for i := range units {
		if i != last {
			units[i].AppendOp(code.OpJmp)
			units[i].AppendInt(int64(end))
		}
		t.AppendRaw(units[i].Code())
	}

	return nil
}

// withBaseOffset runs fn with comp.baseOffset set to base, restoring the
// previous value afterwards even if fn fails.
func (comp *Compiler) withBaseOffset(base int, fn func() error) error {
	saved := comp.baseOffset
	comp.baseOffset = base
	defer func() { comp.baseOffset = saved }()

	return fn()
}

// compileExpr dispatches on the expression kind and appends its code to t.
// Unimplemented kinds are reported as errors; an unknown kind is a bug in the
// AST and panics.
func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
	switch expr.Kind {
	case ast.ExprKindBinary:
		return comp.compileBinaryExpr(t, expr.Value.(ast.ExprBinary))
	case ast.ExprKindUnary:
		return comp.compileUnaryExpr(t, expr.Value.(ast.ExprUnary))
	case ast.ExprKindCall:
		return comp.compileCallExpr(t, expr.Value.(ast.ExprCall))
	case ast.ExprKindSubscription:
		return comp.compileSubscriptionExpr(t, expr.Value.(ast.ExprSubscription))
	case ast.ExprKindGroup:
		return comp.compileGroupExpr(t, expr.Value.(ast.ExprGroup))
	case ast.ExprKindIdent:
		return comp.compileIdentExpr(t, expr.Value.(ast.ExprIdent))
	case ast.ExprKindIntLit:
		return comp.compileIntLitExpr(t, expr.Value.(ast.ExprIntLit))
	case ast.ExprKindFloatLit:
		return comp.compileFloatLitExpr(t, expr.Value.(ast.ExprFloatLit))
	case ast.ExprKindStringLit:
		return comp.compileStringLitExpr(t, expr.Value.(ast.ExprStringLit))
	case ast.ExprKindBoolLit:
		return comp.compileBoolLitExpr(t, expr.Value.(ast.ExprBoolLit))
	case ast.ExprKindNullLit:
		return comp.compileNullLitExpr(t, expr.Value.(ast.ExprNullLit))
	case ast.ExprKindFnLit, ast.ExprKindArrayLit, ast.ExprKindThis:
		return fmt.Errorf("expression of kind %v not implemented", expr.Kind)
	default:
		panic("unknown expression kind")
	}
}

// compileBinaryExpr compiles both operands (left, then right) followed by the
// operator's opcode. Assignment is delegated to compileAssignExpr.
func (comp *Compiler) compileBinaryExpr(t *code.Builder, expr ast.ExprBinary) error {
	if expr.Op == ast.BinOpAssign {
		return comp.compileAssignExpr(t, expr)
	}

	if err := comp.compileExpr(t, expr.Left); err != nil {
		return err
	}

	if err := comp.compileExpr(t, expr.Right); err != nil {
		return err
	}

	switch expr.Op {
	case ast.BinOpPlus:
		t.AppendOp(code.OpAdd)
	case ast.BinOpMinus:
		t.AppendOp(code.OpSub)
	case ast.BinOpLte:
		t.AppendOp(code.OpLte)
	case ast.BinOpStar, ast.BinOpSlash, ast.BinOpPercent,
		ast.BinOpEq, ast.BinOpNeq,
		ast.BinOpLt, ast.BinOpGt, ast.BinOpGte:
		// TODO: emit OpMul, OpDiv, OpMod, OpEq, OpNeq, OpLt, OpGt and OpGte
		// once those opcodes are available.
		return fmt.Errorf("binary operator %v not implemented", expr.Op)
	default:
		panic("unknown binary operator")
	}

	return nil
}

// compileAssignExpr compiles `name = value` for a previously declared local:
// the right-hand side is compiled, then OpSetLocal stores it into the
// variable's slot. Only identifier lvalues are supported.
func (comp *Compiler) compileAssignExpr(t *code.Builder, expr ast.ExprBinary) error {
	if expr.Left.Kind != ast.ExprKindIdent {
		return fmt.Errorf("lvalues other than identifiers not implemented")
	}

	name := expr.Left.Value.(ast.ExprIdent).Value.Value
	symbol, ok := comp.scopes.Lookup(name)
	if !ok {
		return fmt.Errorf("variable %s not declared", name)
	}

	if err := comp.compileExpr(t, expr.Right); err != nil {
		return err
	}

	t.AppendOp(code.OpSetLocal)
	t.AppendInt(int64(symbol.localIndex))

	return nil
}

// compileUnaryExpr compiles the operand followed by the operator. No unary
// operator has an opcode yet, so known operators are reported as errors; an
// unknown operator is a bug in the AST and panics.
func (comp *Compiler) compileUnaryExpr(t *code.Builder, expr ast.ExprUnary) error {
	if err := comp.compileExpr(t, expr.Value); err != nil {
		return err
	}

	switch expr.Op {
	case ast.UnOpBang, ast.UnOpMinus:
		return fmt.Errorf("unary operator %v not implemented", expr.Op)
	default:
		panic("unknown unary operator")
	}
}

// compileCallExpr compiles the callee, then each argument left to right, then
// OpCall with the argument count as operand.
func (comp *Compiler) compileCallExpr(t *code.Builder, expr ast.ExprCall) error {
	if err := comp.compileExpr(t, expr.Callee); err != nil {
		return err
	}

	for _, arg := range expr.Args {
		if err := comp.compileExpr(t, arg); err != nil {
			return err
		}
	}

	t.AppendOp(code.OpCall)
	t.AppendInt(int64(len(expr.Args)))

	return nil
}

// compileSubscriptionExpr compiles `obj[key]`: the object, then the key, then
// OpIndex.
func (comp *Compiler) compileSubscriptionExpr(t *code.Builder, expr ast.ExprSubscription) error {
	if err := comp.compileExpr(t, expr.Obj); err != nil {
		return err
	}

	if err := comp.compileExpr(t, expr.Key); err != nil {
		return err
	}

	t.AppendOp(code.OpIndex)

	return nil
}

// compileGroupExpr compiles a parenthesized expression; grouping emits no
// code of its own.
func (comp *Compiler) compileGroupExpr(t *code.Builder, expr ast.ExprGroup) error {
	return comp.compileExpr(t, expr.Value)
}

// compileIdentExpr loads a declared variable with OpGetLocal.
func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) error {
	symbol, ok := comp.scopes.Lookup(expr.Value.Value)
	if !ok {
		return fmt.Errorf("undefined symbol %s", expr.Value.Value)
	}

	// TODO: Add boundaries to check how the symbol should be fetched. (local, env, global, etc.)
	t.AppendOp(code.OpGetLocal)
	t.AppendInt(int64(symbol.localIndex))

	return nil
}

// compileIntLitExpr pushes an integer literal.
func (comp *Compiler) compileIntLitExpr(t *code.Builder, expr ast.ExprIntLit) error {
	t.AppendOp(code.OpPushInt)
	t.AppendInt(int64(expr.Value))
	return nil
}

// compileFloatLitExpr pushes a float literal.
func (comp *Compiler) compileFloatLitExpr(t *code.Builder, expr ast.ExprFloatLit) error {
	t.AppendOp(code.OpPushFloat)
	t.AppendFloat(expr.Value)
	return nil
}

// compileStringLitExpr pushes a string literal.
func (comp *Compiler) compileStringLitExpr(t *code.Builder, expr ast.ExprStringLit) error {
	t.AppendOp(code.OpPushString)
	t.AppendString(expr.Value)
	return nil
}

// compileBoolLitExpr pushes true or false using the dedicated opcodes.
func (comp *Compiler) compileBoolLitExpr(t *code.Builder, expr ast.ExprBoolLit) error {
	if expr.Value {
		t.AppendOp(code.OpPushTrue)
	} else {
		t.AppendOp(code.OpPushFalse)
	}
	return nil
}

// compileNullLitExpr pushes null.
func (comp *Compiler) compileNullLitExpr(t *code.Builder, expr ast.ExprNullLit) error {
	t.AppendOp(code.OpPushNull)
	return nil
}

// compileBlockNode compiles the statements of a block in order into t. No new
// scope is pushed for the block.
func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) error {
	for _, stmt := range block.Stmts {
		if err := comp.compileStmt(t, stmt); err != nil {
			return err
		}
	}

	return nil
}