about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--pkg/lang/vm/code/op.go2
-rw-r--r--pkg/lang/vm/errors.go11
-rw-r--r--pkg/lang/vm/exec.go81
-rw-r--r--pkg/lang/vm/stack.go106
-rw-r--r--pkg/lang/vm/text/compiler_test.go6
-rw-r--r--pkg/lang/vm/text/op.go2
-rw-r--r--pkg/lang/vm/vm.go15
-rw-r--r--pkg/lang/vm/vm_test.go30
8 files changed, 138 insertions, 115 deletions
diff --git a/pkg/lang/vm/code/op.go b/pkg/lang/vm/code/op.go
index cb2811a..84d468f 100644
--- a/pkg/lang/vm/code/op.go
+++ b/pkg/lang/vm/code/op.go
@@ -16,13 +16,13 @@ const (
 	OpPushFunction
 	OpPushObject
 
+	OpShift
 	OpDrop
 
 	OpGetGlobal
 	OpGetLocal
 	OpGetMember
 	OpGetMethod
-	OpGetArg
 	OpGetEnv
 
 	OpAdd
diff --git a/pkg/lang/vm/errors.go b/pkg/lang/vm/errors.go
index 045f9a9..bbc66b8 100644
--- a/pkg/lang/vm/errors.go
+++ b/pkg/lang/vm/errors.go
@@ -20,12 +20,13 @@ func (e Error) Error() string {
 // Fatal errors
 
 var (
-	ErrCallStackOverflow   = errors.New("call stack overflow (max depth: 1000)")
-	ErrLocalStackOverflow  = errors.New("local stack overflow (max depth: 1000)")
-	ErrNoPreviousCallFrame = errors.New("no previous call frame")
-	ErrCantPopRootFrame    = errors.New("cannot pop root frame")
+	ErrStackOverflow  = errors.New("stack overflow (max depth: 1000)")
+	ErrStackUnderflow = errors.New("local stack underflow")
 
-	ErrCallFrameEmpty = errors.New("current call frame is empty")
+	ErrReachedMaxCallDepth  = errors.New("reached max call depth (max depth: 1000)")
+	ErrReachedRootCallFrame = errors.New("reached root call frame")
+
+	ErrCallBaseCantBeNegative = errors.New("call base cannot be negative")
 )
 
 type ErrLocalIndexOutOfBounds struct {
diff --git a/pkg/lang/vm/exec.go b/pkg/lang/vm/exec.go
index 1b94572..4f37ad3 100644
--- a/pkg/lang/vm/exec.go
+++ b/pkg/lang/vm/exec.go
@@ -6,64 +6,45 @@ import (
 )
 
 func (vm *VM) execPushInt(x int64) {
-	vm.stack.Top().Push(value.NewInt(x))
+	vm.stack.Push(value.NewInt(x))
 }
 
 func (vm *VM) execPushFloat(x float64) {
-	vm.stack.Top().Push(value.NewFloat(x))
+	vm.stack.Push(value.NewFloat(x))
 }
 
 func (vm *VM) execPushString(str string) {
-	vm.stack.Top().Push(value.NewString(str))
+	vm.stack.Push(value.NewString(str))
 }
 
 func (vm *VM) execPushBool(b bool) {
-	vm.stack.Top().Push(value.NewBool(b))
+	vm.stack.Push(value.NewBool(b))
 }
 
 func (vm *VM) execPushNull() {
-	vm.stack.Top().Push(value.NewNull())
+	vm.stack.Push(value.NewNull())
 }
 
 func (vm *VM) execPushArray() {
-	vm.stack.Top().Push(value.NewArray([]value.Value{}))
+	vm.stack.Push(value.NewArray([]value.Value{}))
 }
 
 func (vm *VM) execGetLocal(offset int) error {
-	top := vm.stack.Top()
-
-	local, err := top.At(int(offset))
+	local, err := vm.stack.Local(int(offset))
 	if err != nil {
 		return err
 	}
 
-	top.Push(local)
-	return nil
-}
-
-func (vm *VM) execGetArg() error {
-	prev, err := vm.stack.Prev()
-	if err != nil {
-		return err
-	}
-
-	arg, err := prev.Pop()
-	if err != nil {
-		return err
-	}
-
-	vm.stack.Top().Push(arg)
+	vm.stack.Push(local)
 	return nil
 }
 
 func (vm *VM) execAdd() error {
-	top := vm.stack.Top()
-
-	x, err := top.Pop()
+	x, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
-	y, err := top.Pop()
+	y, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
@@ -122,18 +103,16 @@ func (vm *VM) execAdd() error {
 		}
 	}
 
-	top.Push(res)
+	vm.stack.Push(res)
 	return nil
 }
 
 func (vm *VM) execSub() error {
-	top := vm.stack.Top()
-
-	x, err := top.Pop()
+	x, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
-	y, err := top.Pop()
+	y, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
@@ -181,18 +160,16 @@ func (vm *VM) execSub() error {
 		}
 	}
 
-	top.Push(res)
+	vm.stack.Push(res)
 	return nil
 }
 
 func (vm *VM) execIndex() error {
-	top := vm.stack.Top()
-
-	v, err := top.Pop()
+	v, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
-	i, err := top.Pop()
+	i, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
@@ -209,7 +186,7 @@ func (vm *VM) execIndex() error {
 					Len:   arr.Len(),
 				}
 			}
-			top.Push(arr.At(int(idx)))
+			vm.stack.Push(arr.At(int(idx)))
 		default:
 			return ErrInvalidOperandTypes{
 				Op: code.OpIndex,
@@ -229,13 +206,11 @@ func (vm *VM) execIndex() error {
 }
 
 func (vm *VM) execLte() error {
-	top := vm.stack.Top()
-
-	x, err := top.Pop()
+	x, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
-	y, err := top.Pop()
+	y, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
@@ -283,14 +258,12 @@ func (vm *VM) execLte() error {
 		}
 	}
 
-	top.Push(res)
+	vm.stack.Push(res)
 	return nil
 }
 
 func (vm *VM) execJumpIf(pc int, cond bool) error {
-	top := vm.stack.Top()
-
-	b, err := top.Pop()
+	b, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
@@ -319,9 +292,7 @@ func (vm *VM) execJumpIf(pc int, cond bool) error {
 }
 
 func (vm *VM) execTempArrLen() error {
-	top := vm.stack.Top()
-
-	a, err := top.Pop()
+	a, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
@@ -330,7 +301,7 @@ func (vm *VM) execTempArrLen() error {
 	case value.ArrayType:
 		arr := a.Data().(value.ArrayData)
 		res := value.NewInt(int64(arr.Len()))
-		top.Push(res)
+		vm.stack.Push(res)
 	default:
 		return ErrInvalidOperandTypes{
 			Op: code.OpTempArrLen,
@@ -342,13 +313,11 @@ func (vm *VM) execTempArrLen() error {
 }
 
 func (vm *VM) execTempArrPush() error {
-	top := vm.stack.Top()
-
-	a, err := top.Pop()
+	a, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
-	e, err := top.Pop()
+	e, err := vm.stack.Pop()
 	if err != nil {
 		return err
 	}
diff --git a/pkg/lang/vm/stack.go b/pkg/lang/vm/stack.go
index 9cf7db8..937be32 100644
--- a/pkg/lang/vm/stack.go
+++ b/pkg/lang/vm/stack.go
@@ -4,70 +4,96 @@ import (
 	"jinx/pkg/lang/vm/value"
 )
 
-type CallStack []*LocalStack
-
-func NewCallStack() CallStack {
-	return []*LocalStack{{}}
+type Stack struct {
+	data  []value.Value
+	calls []callFrame
 }
 
-func (cs *CallStack) Push() error {
-	if len(*cs) > 1000 {
-		return ErrCallStackOverflow
+func NewStack() Stack {
+	data := make([]value.Value, 0, 64)
+	calls := make([]callFrame, 0, 8)
+
+	calls = append(calls, callFrame{
+		pc:       0,
+		returnPc: 0,
+		base:     0,
+	})
+
+	return Stack{
+		data:  data,
+		calls: calls,
 	}
+}
 
-	*cs = append(*cs, &LocalStack{})
-	return nil
+func (stack *Stack) Push(value value.Value) {
+	stack.data = append(stack.data, value)
 }
 
-func (cs *CallStack) Pop() error {
-	if len(*cs) <= 1 {
-		return ErrCantPopRootFrame
+func (stack *Stack) Pop() (value.Value, error) {
+	if stack.IsEmpty() || stack.ReachedBaseOfCall() {
+		return value.Value{}, ErrStackUnderflow
 	}
 
-	*cs = (*cs)[:len(*cs)-1]
-	return nil
+	v, err := stack.Top()
+	if err != nil {
+		return value.Value{}, err
+	}
+
+	stack.data = stack.data[:stack.Len()-1]
+	return v, nil
 }
 
-func (cs *CallStack) Top() *LocalStack {
-	return (*cs)[len(*cs)-1]
+func (stack *Stack) Local(offset int) (value.Value, error) {
+	if stack.ReachedBaseOfCall() {
+		return value.Value{}, ErrStackUnderflow
+	}
+
+	if offset < 0 || offset >= stack.Len() {
+		return value.Value{}, ErrLocalIndexOutOfBounds{Index: offset, Len: stack.Len()}
+	}
+
+	base := stack.TopCall().base
+	return stack.data[base+offset], nil
 }
 
-func (cs *CallStack) Prev() (*LocalStack, error) {
-	if len(*cs) <= 1 {
-		return nil, ErrNoPreviousCallFrame
+func (stack *Stack) Top() (value.Value, error) {
+	if stack.IsEmpty() || stack.ReachedBaseOfCall() {
+		return value.Value{}, ErrStackUnderflow
 	}
 
-	return (*cs)[len(*cs)-2], nil
+	return stack.data[stack.Len()-1], nil
 }
 
-type LocalStack []value.Value
+func (stack *Stack) ShiftTopCallBase(by int) error {
+	call := stack.TopCall()
+	newBase := call.base - by
 
-func (ls *LocalStack) Push(v value.Value) error {
-	if len(*ls) > 1000 {
-		return ErrLocalStackOverflow
+	if newBase < 0 {
+		return ErrCallBaseCantBeNegative
 	}
 
-	*ls = append(*ls, v)
+	call.base = newBase
 	return nil
 }
 
-func (ls *LocalStack) Pop() (value.Value, error) {
-	if len(*ls) == 0 {
-		return value.Value{}, ErrCallFrameEmpty
-	}
+func (stack *Stack) Len() int {
+	return len(stack.data)
+}
 
-	v := (*ls)[len(*ls)-1]
-	*ls = (*ls)[:len(*ls)-1]
-	return v, nil
+func (stack *Stack) IsEmpty() bool {
+	return len(stack.data) == 0
 }
 
-func (ls *LocalStack) At(at int) (value.Value, error) {
-	if at >= len(*ls) {
-		return value.Value{}, ErrLocalIndexOutOfBounds{
-			Index: at,
-			Len:   len(*ls),
-		}
-	}
+func (stack *Stack) TopCall() *callFrame {
+	return &stack.calls[len(stack.calls)-1]
+}
+
+func (stack *Stack) ReachedBaseOfCall() bool {
+	return stack.TopCall().base == stack.Len()
+}
 
-	return (*ls)[at], nil
+type callFrame struct {
+	pc       int // Beginning of the called function.
+	returnPc int // Where to return to after the called function returns.
+	base     int // Base of the local variables on the data stack.
 }
diff --git a/pkg/lang/vm/text/compiler_test.go b/pkg/lang/vm/text/compiler_test.go
index cf2f6a9..237e884 100644
--- a/pkg/lang/vm/text/compiler_test.go
+++ b/pkg/lang/vm/text/compiler_test.go
@@ -13,8 +13,7 @@ import (
 
 func TestSimple(t *testing.T) {
 	src := `
-	get_arg
-	get_arg
+	add
 	sub
 	ret
 	`
@@ -24,8 +23,7 @@ func TestSimple(t *testing.T) {
 	require.NoError(t, err)
 
 	parts := [][]byte{
-		opBin(code.OpGetArg),
-		opBin(code.OpGetArg),
+		opBin(code.OpAdd),
 		opBin(code.OpSub),
 		opBin(code.OpRet),
 	}
diff --git a/pkg/lang/vm/text/op.go b/pkg/lang/vm/text/op.go
index bbcc2d4..df40d8f 100644
--- a/pkg/lang/vm/text/op.go
+++ b/pkg/lang/vm/text/op.go
@@ -15,12 +15,12 @@ var (
 		code.OpPushArray:    "push_array",
 		code.OpPushFunction: "push_function",
 		code.OpPushObject:   "push_object",
+		code.OpShift:        "shift",
 		code.OpDrop:         "drop",
 		code.OpGetGlobal:    "get_global",
 		code.OpGetLocal:     "get_local",
 		code.OpGetMember:    "get_member",
 		code.OpGetMethod:    "get_method",
-		code.OpGetArg:       "get_arg",
 		code.OpGetEnv:       "get_env",
 		code.OpAdd:          "add",
 		code.OpSub:          "sub",
diff --git a/pkg/lang/vm/vm.go b/pkg/lang/vm/vm.go
index 8fdc0d5..a428bfd 100644
--- a/pkg/lang/vm/vm.go
+++ b/pkg/lang/vm/vm.go
@@ -8,19 +8,19 @@ import (
 type VM struct {
 	code  *code.Code
 	pc    int
-	stack CallStack
+	stack Stack
 }
 
 func New(code *code.Code) *VM {
 	return &VM{
 		code:  code,
 		pc:    0,
-		stack: NewCallStack(),
+		stack: NewStack(),
 	}
 }
 
 func (vm *VM) GetResult() (string, error) {
-	res, err := vm.stack.Top().Pop()
+	res, err := vm.stack.Pop()
 	if err != nil {
 		return "", err
 	}
@@ -89,8 +89,13 @@ func (vm *VM) step(op code.Op) (stepDecision, error) {
 	case code.OpPushObject:
 		panic("not implemented")
 
+	case code.OpShift:
+		by, advance := vm.code.GetInt(vm.pc)
+		vm.pc += advance
+
+		err = vm.stack.ShiftTopCallBase(int(by))
 	case code.OpDrop:
-		_, err = vm.stack.Top().Pop()
+		_, err = vm.stack.Pop()
 
 	case code.OpGetGlobal:
 		panic("not implemented")
@@ -103,8 +108,6 @@ func (vm *VM) step(op code.Op) (stepDecision, error) {
 		panic("not implemented")
 	case code.OpGetMethod:
 		panic("not implemented")
-	case code.OpGetArg:
-		vm.execGetArg()
 	case code.OpGetEnv:
 		panic("not implemented")
 
diff --git a/pkg/lang/vm/vm_test.go b/pkg/lang/vm/vm_test.go
index 3622e0b..f87182f 100644
--- a/pkg/lang/vm/vm_test.go
+++ b/pkg/lang/vm/vm_test.go
@@ -10,6 +10,28 @@ import (
 	"github.com/stretchr/testify/require"
 )
 
+func TestSimpleSub(t *testing.T) {
+	src := `
+	push_int 1
+	push_int 2
+	sub
+	`
+
+	test(t, src, "1")
+}
+
+func TestGetLocal(t *testing.T) {
+	src := `
+	push_int 404
+	push_int 1
+	push_int 2
+	add
+	get_local 1
+	`
+
+	test(t, src, "3")
+}
+
 func TestFibonacci(t *testing.T) {
 	src := `
 	# Array stored in local 0
@@ -65,8 +87,12 @@ func TestFibonacci(t *testing.T) {
 	jt @fib_loop
 	`
 
-	bc := compile(t, src)
+	test(t, src, "[1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]")
+}
 
+
+func test(t *testing.T, src string, expected string) {
+	bc := compile(t, src)
 	vm := vm.New(&bc)
 	err := vm.Run()
 	require.NoError(t, err)
@@ -74,7 +100,7 @@ func TestFibonacci(t *testing.T) {
 	res, err := vm.GetResult()
 	require.NoError(t, err)
 
-	require.Equal(t, "[1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]", res)
+	require.Equal(t, expected, res)
 }
 
 func compile(t *testing.T, src string) code.Code {