about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--pkg/lang/compiler/compiler.go132
-rw-r--r--pkg/lang/compiler/compiler_test.go54
-rw-r--r--pkg/lang/compiler/scope_chain.go56
-rw-r--r--pkg/lang/vm/code/op.go1
-rw-r--r--pkg/lang/vm/exec.go57
-rw-r--r--pkg/lang/vm/text/decompiler.go1
-rw-r--r--pkg/lang/vm/text/op.go1
-rw-r--r--pkg/lang/vm/vm.go2
8 files changed, 277 insertions, 27 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index b1fd961..3a5efd9 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -48,6 +48,9 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 	case ast.StmtKindForCond:
 		forCondStmt := stmt.Value.(ast.StmtForCond)
 		err = comp.compileForCondStmt(t, forCondStmt)
+	case ast.StmtKindForIn:
+		forCondIn := stmt.Value.(ast.StmtForIn)
+		err = comp.compileForInStmt(t, forCondIn)
 	case ast.StmtKindExpr:
 		expr := stmt.Value.(ast.StmtExpr).Value
 		err = comp.compileExpr(t, expr)
@@ -59,14 +62,14 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 }
 
 func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error {
-	if !comp.scopes.Declare(decl.Name.Value) {
-		return fmt.Errorf("variable %s already declared", decl.Name.Value)
-	}
-
 	if err := comp.compileExpr(t, decl.Value); err != nil {
 		return err
 	}
 
+	if _, ok := comp.scopes.Declare(decl.Name.Value); !ok {
+		return fmt.Errorf("variable %s already declared", decl.Name.Value)
+	}
+
 	return nil
 }
 
@@ -90,7 +93,6 @@ func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error {
 	subUnits := make([]code.Builder, 0, len(ifStmt.Conds))
 
 	totalLength := 0
-	jmpLength := 9 // The length of either of the jump parts: op: 1 + uint: 8 = 9
 
 	for i, cond := range ifStmt.Conds {
 		// Then block
@@ -101,7 +103,7 @@ func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error {
 
 		totalLength += thenTarget.Len()
 		if i != len(ifStmt.Conds)-1 {
-			totalLength += jmpLength
+			totalLength += lengthOfAJumpInstruction
 		}
 
 		// Condition check
@@ -111,7 +113,7 @@ func (comp *Compiler) compileIfStmt(t *code.Builder, ifStmt ast.StmtIf) error {
 				return err
 			}
 
-			totalLength += conditionTarget.Len() + jmpLength // condjmp
+			totalLength += conditionTarget.Len() + lengthOfAJumpInstruction // condjmp
 
 			conditionTarget.AppendOp(code.OpJf)
 			// Condition jump
@@ -153,8 +155,6 @@ func (comp *Compiler) compileForCondStmt(t *code.Builder, forCondStmt ast.StmtFo
 		return err
 	}
 
-	jmpLength := 9
-
 	conditionTarget := code.NewBuilder()
 	if !forCondStmt.Cond.IsEmpty() {
 		// Condition check
@@ -162,7 +162,7 @@ func (comp *Compiler) compileForCondStmt(t *code.Builder, forCondStmt ast.StmtFo
 			return err
 		}
 
-		endOfFor := conditionTarget.Len() + doTarget.Len() + jmpLength*2
+		endOfFor := conditionTarget.Len() + doTarget.Len() + lengthOfAJumpInstruction*2
 
 		// Condition jump
 		conditionTarget.AppendOp(code.OpJf)
@@ -181,6 +181,111 @@ func (comp *Compiler) compileForCondStmt(t *code.Builder, forCondStmt ast.StmtFo
 	return nil
 }
 
+func (comp *Compiler) compileForInStmt(t *code.Builder, forInStmt ast.StmtForIn) error {
+	// Mostly same as ForCond, but the condition is implicit.
+
+	// Example for: `for x in [] {}`
+
+	// 0. Preparation
+	// push_array # collection stored in local 0
+	// push_int 0 # i stored in local 1
+	// push_null # x stored in local 2
+	// 1. Condition check (i < x.length())
+	// @check:
+	// get_local 1
+	// get_local 0
+	// get_member "$length"
+	// call 0
+	// lt
+	// 2. Condition jump
+	// jf @end
+	// 3.1 Do preparation (aka setting the x variable)
+	// get_local 0
+	// get_local 1
+	// index
+	// set_local 2
+	// 3. Do block
+	// ...
+	// 4. Repeat jump:
+	// jmp @check
+	// @end:
+	// halt
+
+	// Preparation
+	preparationTarget := code.NewBuilder()
+	if err := comp.compileExpr(&preparationTarget, forInStmt.Collection); err != nil {
+		return err
+	}
+	collectionLocal := comp.scopes.DeclareAnonymous()
+
+	preparationTarget.AppendOp(code.OpPushInt)
+	preparationTarget.AppendInt(0)
+	iLocal := comp.scopes.DeclareAnonymous()
+
+	preparationTarget.AppendOp(code.OpPushNull)
+	nameLocal, ok := comp.scopes.Declare(forInStmt.Name.Value)
+	if !ok {
+		return fmt.Errorf("variable %s already declared", forInStmt.Name.Value)
+	}
+
+	// Condition check
+	conditionTarget := code.NewBuilder()
+
+	conditionTarget.AppendOp(code.OpGetLocal)
+	conditionTarget.AppendInt(int64(iLocal))
+
+	conditionTarget.AppendOp(code.OpGetLocal)
+	conditionTarget.AppendInt(int64(collectionLocal))
+
+	conditionTarget.AppendOp(code.OpGetMember)
+	conditionTarget.AppendString("length")
+
+	conditionTarget.AppendOp(code.OpCall)
+	conditionTarget.AppendInt(0)
+
+	conditionTarget.AppendOp(code.OpLt)
+
+	// Do Preparation
+	doPreparationTarget := code.NewBuilder()
+
+	doPreparationTarget.AppendOp(code.OpGetLocal)
+	doPreparationTarget.AppendInt(int64(collectionLocal))
+
+	doPreparationTarget.AppendOp(code.OpGetLocal)
+	doPreparationTarget.AppendInt(int64(iLocal))
+
+	doPreparationTarget.AppendOp(code.OpIndex)
+
+	doPreparationTarget.AppendOp(code.OpSetLocal)
+	doPreparationTarget.AppendInt(int64(nameLocal))
+
+	// Do block
+	doTarget := code.NewBuilder()
+	if err := comp.compileBlockNode(&doTarget, forInStmt.Do); err != nil {
+		return err
+	}
+
+	// Condition Jump
+
+	endOfFor := preparationTarget.Len() + conditionTarget.Len() + doPreparationTarget.Len() + doTarget.Len() + lengthOfAJumpInstruction*2
+
+	conditionTarget.AppendOp(code.OpJf)
+	conditionTarget.AppendReferenceToPc(int64(endOfFor))
+
+	subUnit := preparationTarget
+	subUnit.AppendBuilderWithoutAdjustingReferences(conditionTarget)
+	subUnit.AppendBuilder(doPreparationTarget)
+	subUnit.AppendBuilder(doTarget)
+
+	// Repeat jump
+	subUnit.AppendOp(code.OpJmp)
+	subUnit.AppendReferenceToPc(int64(preparationTarget.Len()))
+
+	t.AppendBuilder(subUnit)
+
+	return nil
+}
+
 func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 	switch expr.Kind {
 	case ast.ExprKindBinary:
@@ -252,8 +357,7 @@ func (comp *Compiler) compileBinaryExpr(t *code.Builder, expr ast.ExprBinary) er
 		// t.AppendOp(code.OpNeq)
 		panic("not implemented")
 	case ast.BinOpLt:
-		// t.AppendOp(code.OpLt)
-		panic("not implemented")
+		t.AppendOp(code.OpLt)
 	case ast.BinOpLte:
 		t.AppendOp(code.OpLte)
 	case ast.BinOpGt:
@@ -417,3 +521,7 @@ func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) err
 
 	return nil
 }
+
+const (
+	lengthOfAJumpInstruction = 9 // The length of a jump Op (jmp, jf, jt) and it's following 64-bit integer.
+)
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index 4fd4bac..326ad68 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -291,6 +291,60 @@ func TestForCond(t *testing.T) {
 	mustCompileTo(t, src, expected)
 }
 
+func TestForIn(t *testing.T) {
+	src := `
+	for x in [1, 2, 3] {
+		"say"(x)
+	}
+	`
+
+	expected := `
+	push_array
+
+	get_local 0
+	get_member "$push"
+	push_int 1
+	call 1
+
+	get_local 0
+	get_member "$push"
+	push_int 2
+	call 1
+
+	get_local 0
+	get_member "$push"
+	push_int 3
+	call 1
+
+	push_int 0
+	push_null
+
+	@check:
+	get_local 1
+	get_local 0
+	get_member "$length"
+	call 0
+	lt
+
+	jf @end
+
+	get_local 0
+	get_local 1
+	index
+	set_local 2
+
+	push_string "say"
+	get_local 2
+	call 1
+
+	jmp @check
+	@end:
+	halt
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
 func mustCompileTo(t *testing.T, src, expected string) {
 	scanner := scanner.New(strings.NewReader(src))
 	tokens, err := scanner.Scan()
diff --git a/pkg/lang/compiler/scope_chain.go b/pkg/lang/compiler/scope_chain.go
index 8d942ea..6b7e693 100644
--- a/pkg/lang/compiler/scope_chain.go
+++ b/pkg/lang/compiler/scope_chain.go
@@ -9,8 +9,9 @@ type ScopeChain struct {
 func NewScopeChain() ScopeChain {
 	scopes := make([]Scope, 1)
 	scopes[0] = Scope{
-		kind:    ScopeKindGlobal,
-		symbols: make(map[string]Symbol),
+		kind:         ScopeKindGlobal,
+		nameToSymbol: make(map[string]int),
+		symbols:      make([]Symbol, 0),
 	}
 
 	return ScopeChain{
@@ -24,8 +25,9 @@ func (sc *ScopeChain) Current() *Scope {
 
 func (sc *ScopeChain) Enter(kind ScopeKind) {
 	sc.scopes = append(sc.scopes, Scope{
-		kind:    kind,
-		symbols: make(map[string]Symbol),
+		kind:         kind,
+		nameToSymbol: make(map[string]int),
+		symbols:      make([]Symbol, 0),
 	})
 }
 
@@ -34,28 +36,51 @@ func (sc *ScopeChain) Exit() {
 	sc.scopes = sc.scopes[:len(sc.scopes)-1]
 }
 
-func (sc *ScopeChain) Declare(name string) bool {
+func (sc *ScopeChain) Declare(name string) (int, bool) {
 	// Check whether the symbol is already declared in any of the scopes.
 	for _, scope := range sc.scopes {
-		if _, ok := scope.symbols[name]; ok {
-			return false
+		if _, ok := scope.nameToSymbol[name]; ok {
+			return 0, false
 		}
 	}
 
+	current := sc.Current()
+	index := len(current.symbols)
+
 	// Declare the symbol in the current scope.
-	sc.Current().symbols[name] = Symbol{
+	current.symbols = append(current.symbols, Symbol{
 		kind:       SymbolKindVariable,
 		name:       name,
-		localIndex: len(sc.Current().symbols),
-	}
+		localIndex: index,
+	})
+
+	current.nameToSymbol[name] = index
+
+	return index, true
+}
+
+func (sc *ScopeChain) DeclareAnonymous() int {
+	current := sc.Current()
+	index := len(current.symbols)
+
+	// Declare the symbol in the current scope.
+	current.symbols = append(current.symbols, Symbol{
+		kind:       SymbolKindVariable,
+		name:       "",
+		localIndex: index,
+	})
+
+	return index
+}
 
-	return true
+func (sc *ScopeChain) DeclareTemporary() int {
+	return len(sc.Current().symbols)
 }
 
 func (sc *ScopeChain) Lookup(name string) (Symbol, bool) {
 	for i := len(sc.scopes) - 1; i >= 0; i-- {
-		if symbol, ok := sc.scopes[i].symbols[name]; ok {
-			return symbol, true
+		if symbol, ok := sc.scopes[i].nameToSymbol[name]; ok {
+			return sc.scopes[i].symbols[symbol], true
 		}
 	}
 
@@ -71,6 +96,7 @@ const (
 )
 
 type Scope struct {
-	kind    ScopeKind
-	symbols map[string]Symbol
+	kind         ScopeKind
+	nameToSymbol map[string]int
+	symbols      []Symbol
 }
diff --git a/pkg/lang/vm/code/op.go b/pkg/lang/vm/code/op.go
index 8b8ff3a..d0b2555 100644
--- a/pkg/lang/vm/code/op.go
+++ b/pkg/lang/vm/code/op.go
@@ -35,6 +35,7 @@ const (
 	OpSub
 	OpMod
 	OpIndex
+	OpLt
 	OpLte
 	OpCall
 
diff --git a/pkg/lang/vm/exec.go b/pkg/lang/vm/exec.go
index 93b8845..181d74e 100644
--- a/pkg/lang/vm/exec.go
+++ b/pkg/lang/vm/exec.go
@@ -662,6 +662,63 @@ func (vm *VM) execIndex() error {
 	return nil
 }
 
+func (vm *VM) execLt() error {
+	y, err := vm.popAndDrop()
+	if err != nil {
+		return err
+	}
+	x, err := vm.popAndDrop()
+	if err != nil {
+		return err
+	}
+
+	var res value.Value
+
+	switch x.Type() {
+	case value.IntType:
+		xv := x.Data().(value.IntData).Get()
+		switch y.Type() {
+		case value.IntType:
+			yv := y.Data().(value.IntData).Get()
+			res = value.NewBool(xv < yv)
+		case value.FloatType:
+			yv := y.Data().(value.FloatData).Get()
+			res = value.NewBool(float64(xv) < yv)
+		default:
+			return ErrInvalidOperandTypes{
+				Op: code.OpLte,
+				X:  x.Type(),
+				Y:  y.Type(),
+			}
+		}
+	case value.FloatType:
+		xv := x.Data().(value.FloatData).Get()
+		switch y.Type() {
+		case value.IntType:
+			yv := y.Data().(value.IntData).Get()
+			res = value.NewBool(xv < float64(yv))
+		case value.FloatType:
+			yv := y.Data().(value.FloatData).Get()
+			res = value.NewBool(xv < yv)
+		default:
+			return ErrInvalidOperandTypes{
+				Op: code.OpLte,
+				X:  x.Type(),
+				Y:  y.Type(),
+			}
+		}
+	default:
+		return ErrInvalidOperandTypes{
+			Op: code.OpLte,
+			X:  x.Type(),
+			Y:  y.Type(),
+		}
+	}
+
+	vm.stack.Push(res)
+	return nil
+}
+
 func (vm *VM) execLte() error {
 	y, err := vm.popAndDrop()
 	if err != nil {
diff --git a/pkg/lang/vm/text/decompiler.go b/pkg/lang/vm/text/decompiler.go
index c8922ec..bef066b 100644
--- a/pkg/lang/vm/text/decompiler.go
+++ b/pkg/lang/vm/text/decompiler.go
@@ -61,6 +61,7 @@ func (d *Decompiler) decompileInstruction(bc code.Raw) (string, code.Raw) {
 		code.OpSub,
 		code.OpMod,
 		code.OpIndex,
+		code.OpLt,
 		code.OpLte,
 		code.OpRet,
 		code.OpTempArrLen,
diff --git a/pkg/lang/vm/text/op.go b/pkg/lang/vm/text/op.go
index 0d01bdb..17f4847 100644
--- a/pkg/lang/vm/text/op.go
+++ b/pkg/lang/vm/text/op.go
@@ -30,6 +30,7 @@ var (
 		code.OpSub:          "sub",
 		code.OpMod:          "mod",
 		code.OpIndex:        "index",
+		code.OpLt:          "lt",
 		code.OpLte:          "lte",
 		code.OpCall:         "call",
 		code.OpJmp:          "jmp",
diff --git a/pkg/lang/vm/vm.go b/pkg/lang/vm/vm.go
index e483176..3f0703a 100644
--- a/pkg/lang/vm/vm.go
+++ b/pkg/lang/vm/vm.go
@@ -168,6 +168,8 @@ func (vm *VM) step(op code.Op) (stepDecision, error) {
 		err = vm.execMod()
 	case code.OpIndex:
 		err = vm.execIndex()
+	case code.OpLt:
+		err = vm.execLt()
 	case code.OpLte:
 		err = vm.execLte()
 	case code.OpCall: