diff options
Diffstat (limited to 'pkg/lang/vm')
| -rw-r--r-- | pkg/lang/vm/code/op.go | 2 | ||||
| -rw-r--r-- | pkg/lang/vm/code/pos.go | 13 | ||||
| -rw-r--r-- | pkg/lang/vm/errors.go | 30 | ||||
| -rw-r--r-- | pkg/lang/vm/exec.go | 82 | ||||
| -rw-r--r-- | pkg/lang/vm/mem/cell.go | 4 | ||||
| -rw-r--r-- | pkg/lang/vm/stack/errors.go | 1 | ||||
| -rw-r--r-- | pkg/lang/vm/stack/stack.go | 59 | ||||
| -rw-r--r-- | pkg/lang/vm/text/decompiler.go | 2 | ||||
| -rw-r--r-- | pkg/lang/vm/text/op.go | 2 | ||||
| -rw-r--r-- | pkg/lang/vm/utils.go | 27 | ||||
| -rw-r--r-- | pkg/lang/vm/value/cells.go | 21 | ||||
| -rw-r--r-- | pkg/lang/vm/value/data.go | 13 | ||||
| -rw-r--r-- | pkg/lang/vm/value/value.go | 5 | ||||
| -rw-r--r-- | pkg/lang/vm/vm.go | 163 | ||||
| -rw-r--r-- | pkg/lang/vm/vm_test.go | 43 |
15 files changed, 374 insertions, 93 deletions
diff --git a/pkg/lang/vm/code/op.go b/pkg/lang/vm/code/op.go index 1ce08f7..b4c172a 100644 --- a/pkg/lang/vm/code/op.go +++ b/pkg/lang/vm/code/op.go @@ -19,7 +19,9 @@ const ( OpDrop + OpAddGlobal OpGetGlobal + OpSetGlobal OpGetLocal OpSetLocal OpGetMember diff --git a/pkg/lang/vm/code/pos.go b/pkg/lang/vm/code/pos.go new file mode 100644 index 0000000..90deb33 --- /dev/null +++ b/pkg/lang/vm/code/pos.go @@ -0,0 +1,13 @@ +package code + +type Pos struct { + Module int + PC int +} + +func NewPos(module int, pc int) Pos { + return Pos{ + Module: module, + PC: pc, + } +} \ No newline at end of file diff --git a/pkg/lang/vm/errors.go b/pkg/lang/vm/errors.go index bfc8a34..0dd73e7 100644 --- a/pkg/lang/vm/errors.go +++ b/pkg/lang/vm/errors.go @@ -9,16 +9,16 @@ import ( ) type Error struct { - Pc int + Pos code.Pos Line int Err error } func (e Error) Error() string { if e.Line == -1 { - return fmt.Sprintf("vm error at pc %d, unknown line: %v", e.Pc, e.Err) + return fmt.Sprintf("vm error in module '%d' at pc %d, unknown line: %v", e.Pos.Module, e.Pos.PC, e.Err) } - return fmt.Sprintf("vm error at pc %d, line %d: %v", e.Pc, e.Line, e.Err) + return fmt.Sprintf("vm error in module '%d' at pc %d, line %d: %v", e.Pos.Module, e.Pos.PC, e.Line, e.Err) } // Fatal errors @@ -111,3 +111,27 @@ type ErrWrongNumberOfArguments struct { func (e ErrWrongNumberOfArguments) Error() string { return fmt.Sprintf("wrong number of arguments: needed %d, got %d", e.Needed, e.Got) } + +type ErrCantAddGlobalFromMain struct { + GlobalName string +} + +func (e ErrCantAddGlobalFromMain) Error() string { + return fmt.Sprintf("can't export '%s' from main module", e.GlobalName) +} + +type ErrGlobalAlreadyExists struct { + GlobalName string +} + +func (e ErrGlobalAlreadyExists) Error() string { + return fmt.Sprintf("global '%s' already exists", e.GlobalName) +} + +type ErrNoSuchGlobal struct { + GlobalName string +} + +func (e ErrNoSuchGlobal) Error() string { + return fmt.Sprintf("no such global '%s'", e.GlobalName) +} diff --git a/pkg/lang/vm/exec.go b/pkg/lang/vm/exec.go index c5b0539..5a8eb05 100644 --- a/pkg/lang/vm/exec.go +++ b/pkg/lang/vm/exec.go @@ -44,7 +44,7 @@ func (vm *VM) execPushArray() error { func (vm *VM) execPushFunction(pc int) { // TODO: Make push ops into functions, where the argCount can be passed. - vm.stack.Push(value.NewFunction(pc, 0)) + vm.stack.Push(value.NewFunction(code.NewPos(vm.module(), pc), 0)) } func (vm *VM) execPushObject() error { @@ -57,6 +57,72 @@ func (vm *VM) execPushObject() error { return nil } +func (vm *VM) execAddGlobal(name string) error { + if !vm.canAddGlobals { + return ErrCantAddGlobalFromMain{GlobalName: name} + } + + if _, ok := vm.globals[name]; ok { + return ErrGlobalAlreadyExists{GlobalName: name} + } + + v, err := vm.stack.Pop() + if err != nil { + return err + } + + globalPtr, err := vm.memory.Allocate(mem.CellKindGlobal) + if err != nil { + return err + } + + globalCell := value.GlobalCell(v) + if err := vm.memory.Set(globalPtr, globalCell); err != nil { + return err + } + + vm.globals[name] = globalPtr + + return nil +} + +func (vm *VM) execGetGlobal(name string) error { + ptr, ok := vm.globals[name] + if !ok { + return ErrNoSuchGlobal{GlobalName: name} + } + + cell, err := vm.getMemCell(ptr, mem.CellKindGlobal, false) + if err != nil { + return err + } + + v := cell.(value.GlobalCell).Get() + v = v.Clone(vm.memory) + + vm.stack.Push(v) + return nil +} + +func (vm *VM) execSetGlobal(name string) error { + ptr, ok := vm.globals[name] + if !ok { + return ErrNoSuchGlobal{GlobalName: name} + } + + new, err := vm.stack.Pop() + if err != nil { + return err + } + + globalCell := value.GlobalCell(new) + if err := vm.memory.Set(ptr, globalCell); err != nil { + return err + } + + return nil +} + func (vm *VM) execGetLocal(offset int) error { local, err := vm.stack.Local(int(offset)) if err != nil { @@ -165,7 +231,7 @@ func (vm *VM) execGetMember(name string) error { return err } - method := value.NewFunction(0, 0).WithData(methodData.WithEnv(envPtr)) + method := value.NewFunction(code.Pos{}, 0).WithData(methodData.WithEnv(envPtr)) // method = method.Clone(vm.memory) will only be necessary when we support methods with environments. vm.stack.Push(method) @@ -219,7 +285,7 @@ func (vm *VM) execGetMember(name string) error { member = member.WithEnv(envPtr) - val := value.NewFunction(0, 0).WithData(member) + val := value.NewFunction(code.Pos{}, 0).WithData(member) vm.stack.Push(val) parent.Drop(vm.memory) @@ -781,7 +847,7 @@ func (vm *VM) execCall(argCount uint) error { } } - if err = vm.stack.PushCall(fn.Pc(), vm.pc, fn.Env()); err != nil { + if err = vm.stack.PushCall(fn.Pos(), vm.pos, fn.Env()); err != nil { return err } @@ -806,7 +872,7 @@ func (vm *VM) execCall(argCount uint) error { vm.stack.Push(arg) } - vm.pc = fn.Pc() + vm.setPos(fn.Pos()) return nil } @@ -821,7 +887,7 @@ func (vm *VM) execJumpIf(pc int, cond bool) error { case value.BoolType: bl := b.Data().(value.BoolData) if bl.Get() == cond { - vm.pc = pc + vm.setPC(pc) } default: var op code.Op @@ -846,13 +912,13 @@ func (vm *VM) execRet() error { return err } - pc, err := vm.popCallAndDrop() + pos, err := vm.popCallAndDrop() if err != nil { return err } vm.stack.Push(returned) - vm.pc = pc + vm.setPos(pos) return nil } diff --git a/pkg/lang/vm/mem/cell.go b/pkg/lang/vm/mem/cell.go index 73aa0c6..85d6c34 100644 --- a/pkg/lang/vm/mem/cell.go +++ b/pkg/lang/vm/mem/cell.go @@ -12,6 +12,8 @@ const ( CellKindEnv CellKindOutlet + CellKindGlobal + CellKindForbidden ) @@ -31,6 +33,8 @@ func (c CellKind) String() string { return "env" case CellKindOutlet: return "outlet" + case CellKindGlobal: + return "global" case CellKindForbidden: return "forbidden" default: diff --git a/pkg/lang/vm/stack/errors.go b/pkg/lang/vm/stack/errors.go index 55ef2ea..55145d5 100644 --- a/pkg/lang/vm/stack/errors.go +++ b/pkg/lang/vm/stack/errors.go @@ -11,6 +11,7 @@ var ( ErrReachedMaxCallDepth = errors.New("reached max call depth (max depth: 1000)") ErrReachedRootCallFrame = errors.New("reached root call frame") + ErrNotAtRootCallFrame = errors.New("not at root call frame") ErrCallBaseCantBeNegative = errors.New("call base cannot be negative") ) diff --git a/pkg/lang/vm/stack/stack.go b/pkg/lang/vm/stack/stack.go index 34ca896..74d4ab8 100644 --- a/pkg/lang/vm/stack/stack.go +++ b/pkg/lang/vm/stack/stack.go @@ -1,6 +1,7 @@ package stack import ( + "jinx/pkg/lang/vm/code" "jinx/pkg/lang/vm/mem" "jinx/pkg/lang/vm/value" ) @@ -17,8 +18,9 @@ type Stack interface { Len() int IsEmpty() bool - PushCall(newPc, returnPc int, env mem.Ptr) error - PopCall() (int, error) + PutRootCall(basePos code.Pos) error + PushCall(newPos, returnPos code.Pos, env mem.Ptr) error + PopCall() (code.Pos, error) CurrentCallEnv() mem.Ptr CallDepth() int @@ -35,12 +37,6 @@ func New() Stack { data := make([]value.Value, 0, 64) calls := make([]callFrame, 0, 8) - calls = append(calls, callFrame{ - pc: 0, - returnPc: 0, - base: 0, - }) - return &stackImpl{ data: data, calls: calls, @@ -106,30 +102,53 @@ func (stack *stackImpl) IsEmpty() bool { return len(stack.data) == 0 } -func (stack *stackImpl) PushCall(newPc, returnPc int, env mem.Ptr) error { +func (stack *stackImpl) PutRootCall(basePos code.Pos) error { + frame := callFrame{ + pos: basePos, + returnPos: basePos, + base: 0, + env: mem.NullPtr, + } + + if stack.CallDepth() == 0 { + stack.calls = append(stack.calls, frame) + + return nil + } + + if stack.CallDepth() != 1 { + return ErrNotAtRootCallFrame + } + + stack.calls[0] = frame + + return nil +} + +func (stack *stackImpl) PushCall(newPos, returnPos code.Pos, env mem.Ptr) error { if stack.CallDepth() == 1000 { return ErrReachedMaxCallDepth } stack.calls = append(stack.calls, callFrame{ - pc: newPc, - returnPc: returnPc, - base: stack.Len(), - env: env, + pos: newPos, + returnPos: returnPos, + base: stack.Len(), + env: env, }) return nil } -func (stack *stackImpl) PopCall() (int, error) { +func (stack *stackImpl) PopCall() (code.Pos, error) { if stack.CallDepth() == 0 { - return 0, ErrReachedRootCallFrame + return code.Pos{}, ErrReachedRootCallFrame } call := stack.topCall() stack.calls = stack.calls[:stack.CallDepth()-1] - return call.returnPc, nil + return call.returnPos, nil } func (stack *stackImpl) CurrentCallEnv() mem.Ptr { @@ -160,8 +179,8 @@ func (stack *stackImpl) topCall() *callFrame { } type callFrame struct { - pc int // Beginning of the called function. - returnPc int // Where to return to after the called function returns. - base int // Base of the local variables on the data stack. - env mem.Ptr // Environment of the called function. + pos code.Pos // Beginning of the called function. + returnPos code.Pos // Where to return to after the called function returns. + base int // Base of the local variables on the data stack. + env mem.Ptr // Environment of the called function. } diff --git a/pkg/lang/vm/text/decompiler.go b/pkg/lang/vm/text/decompiler.go index 2b4e704..8a85d49 100644 --- a/pkg/lang/vm/text/decompiler.go +++ b/pkg/lang/vm/text/decompiler.go @@ -93,7 +93,9 @@ func (d *Decompiler) decompileInstruction(bc code.Raw) (string, code.Raw) { // Operations that take a string. case code.OpPushString, code.OpPushType, + code.OpAddGlobal, code.OpGetGlobal, + code.OpSetGlobal, code.OpSetMember, code.OpGetMember: s, rest := d.decompileString(bc[1:]) diff --git a/pkg/lang/vm/text/op.go b/pkg/lang/vm/text/op.go index ce53f50..ff3c568 100644 --- a/pkg/lang/vm/text/op.go +++ b/pkg/lang/vm/text/op.go @@ -17,7 +17,9 @@ var ( code.OpPushObject: "push_object", code.OpPushType: "push_type", code.OpDrop: "drop", + code.OpAddGlobal: "add_global", code.OpGetGlobal: "get_global", + code.OpSetGlobal: "set_global", code.OpGetLocal: "get_local", code.OpSetLocal: "set_local", code.OpGetMember: "get_member", diff --git a/pkg/lang/vm/utils.go b/pkg/lang/vm/utils.go index 00311ef..2b7fa6f 100644 --- a/pkg/lang/vm/utils.go +++ b/pkg/lang/vm/utils.go @@ -1,6 +1,7 @@ package vm import ( + "jinx/pkg/lang/vm/code" "jinx/pkg/lang/vm/mem" "jinx/pkg/lang/vm/value" ) @@ -14,14 +15,14 @@ func (vm *VM) popAndDrop() (value.Value, error) { return v, nil } -func (vm *VM) popCallAndDrop() (int, error) { +func (vm *VM) popCallAndDrop() (code.Pos, error) { envPtr := vm.stack.CurrentCallEnv() vm.memory.Release(envPtr) for !vm.stack.ReachedBaseOfCall() { _, err := vm.popAndDrop() if err != nil { - return 0, err + return code.Pos{}, err } } @@ -63,6 +64,8 @@ func (vm *VM) getMemCell(ptr mem.Ptr, kind mem.CellKind, allowNil bool) (mem.Cel _, ok = cell.(value.TypeCell) case mem.CellKindObject: _, ok = cell.(value.ObjectCell) + case mem.CellKindGlobal: + _, ok = cell.(value.GlobalCell) } if !ok { @@ -71,3 +74,23 @@ func (vm *VM) getMemCell(ptr mem.Ptr, kind mem.CellKind, allowNil bool) (mem.Cel return cell, nil } + +func (vm *VM) module() int { + return vm.pos.Module +} + +func (vm *VM) pc() int { + return vm.pos.PC +} + +func (vm *VM) setPC(pc int) { + vm.pos.PC = pc +} + +func (vm *VM) advancePC(by int) { + vm.pos.PC += by +} + +func (vm *VM) setPos(pos code.Pos) { + vm.pos = pos +} diff --git a/pkg/lang/vm/value/cells.go b/pkg/lang/vm/value/cells.go index 5abb032..19ecfde 100644 --- a/pkg/lang/vm/value/cells.go +++ b/pkg/lang/vm/value/cells.go @@ -1,6 +1,9 @@ package value -import "jinx/pkg/lang/vm/mem" +import ( + "jinx/pkg/lang/vm/code" + "jinx/pkg/lang/vm/mem" +) type ArrayCell []Value @@ -55,7 +58,7 @@ func (t TypeCell) DropCell(m mem.Mem) { typ := t.Get() for _, f := range typ.Methods { // Wrap data in a Value to drop it. - val := NewFunction(0, 0).WithData(f) + val := NewFunction(code.Pos{}, 0).WithData(f) val.Drop(m) } @@ -101,3 +104,17 @@ func (e EnvCell) MatchingCellKind() mem.CellKind { func (e EnvCell) Get() Env { return Env(e) } + +type GlobalCell Value + +func (g GlobalCell) DropCell(m mem.Mem) { + panic("global cell cannot be dropped") +} + +func (g GlobalCell) MatchingCellKind() mem.CellKind { + return mem.CellKindGlobal +} + +func (g GlobalCell) Get() Value { + return Value(g) +} diff --git a/pkg/lang/vm/value/data.go b/pkg/lang/vm/value/data.go index 18e1b3d..6c3d762 100644 --- a/pkg/lang/vm/value/data.go +++ b/pkg/lang/vm/value/data.go @@ -2,6 +2,7 @@ package value import ( "fmt" + "jinx/pkg/lang/vm/code" "jinx/pkg/lang/vm/mem" "strconv" "strings" @@ -127,7 +128,7 @@ func (n NullData) String(_ mem.Mem) (string, error) { } type FunctionData struct { - pc int + pos code.Pos args uint env mem.Ptr native NativeFunc @@ -135,8 +136,8 @@ type FunctionData struct { type NativeFunc func([]Value) (Value, error) -func (f FunctionData) Pc() int { - return f.pc +func (f FunctionData) Pos() code.Pos { + return f.pos } func (f FunctionData) Args() uint { @@ -152,18 +153,18 @@ func (f FunctionData) Native() NativeFunc { } func (f FunctionData) WithEnv(env mem.Ptr) FunctionData { - return FunctionData{pc: f.pc, args: f.args, env: env, native: f.native} + return FunctionData{pos: f.pos, args: f.args, env: env, native: f.native} } func (f FunctionData) WithArgs(args uint) FunctionData { - return FunctionData{pc: f.pc, args: args, env: f.env, native: f.native} + return FunctionData{pos: f.pos, args: args, env: f.env, native: f.native} } func (f FunctionData) String(_ mem.Mem) (string, error) { if f.native != nil { return fmt.Sprintf("<fn(%d) native>", f.args), nil } else { - return fmt.Sprintf("<fn(%d) %d>", f.args, f.pc), nil + return fmt.Sprintf("<fn(%d) %d:%d>", f.args, f.pos.Module, f.pos.PC), nil } } diff --git a/pkg/lang/vm/value/value.go b/pkg/lang/vm/value/value.go index 5dd5012..b051f04 100644 --- a/pkg/lang/vm/value/value.go +++ b/pkg/lang/vm/value/value.go @@ -1,6 +1,7 @@ package value import ( + "jinx/pkg/lang/vm/code" "jinx/pkg/lang/vm/mem" ) @@ -52,8 +53,8 @@ func NewNull() Value { return Value{t: CORE_TYPE_NULL, d: NullData{}} } -func NewFunction(pc int, args uint) Value { - return Value{t: CORE_TYPE_FUNCTION, d: FunctionData{pc: pc, args: args}} +func NewFunction(pos code.Pos, args uint) Value { + return Value{t: CORE_TYPE_FUNCTION, d: FunctionData{pos: pos, args: args}} } func NewNativeFunction(f NativeFunc, args uint) Value { diff --git a/pkg/lang/vm/vm.go b/pkg/lang/vm/vm.go index 8b47915..4a4aa68 100644 --- a/pkg/lang/vm/vm.go +++ b/pkg/lang/vm/vm.go @@ -7,18 +7,28 @@ import ( ) type VM struct { - code *code.Code - pc int + pos code.Pos + + modules []*code.Code + stack stack.Stack memory mem.Mem + + canAddGlobals bool + globals map[string]mem.Ptr } -func New(code *code.Code) *VM { +func New(main *code.Code, deps []*code.Code) *VM { vm := &VM{ - code: code, - pc: 0, + pos: code.NewPos(0, 0), + + modules: append([]*code.Code{main}, deps...), + stack: stack.New(), memory: mem.New(), + + canAddGlobals: false, + globals: make(map[string]mem.Ptr), } if err := vm.setup(); err != nil { @@ -46,14 +56,56 @@ func (vm *VM) GetResult() (string, error) { } func (vm *VM) Run() error { - for vm.pc < vm.code.Len() { - op, advance := vm.code.GetOp(vm.pc) - vm.pc += advance + vm.canAddGlobals = true + for i := 1; i < len(vm.modules); i++ { + if err := vm.executeModule(i); err != nil { + return err + } + + // Drop all calls from the stack after the module has finished + for vm.stack.CallDepth() > 1 { + if _, err := vm.popCallAndDrop(); err != nil { + return err + } + } + + // Drop the root stack values + for !vm.stack.IsEmpty() { + if _, err := vm.popAndDrop(); err != nil { + return err + } + } + } + + vm.canAddGlobals = false + if err := vm.executeModule(0); err != nil { + return err + } + + return nil +} + +func (vm *VM) executeModule(moduleID int) error { + vm.pos = code.NewPos(moduleID, 0) + + if err := vm.stack.PutRootCall(vm.pos); err != nil { + return err + } + + for { + module := vm.modules[vm.module()] - if decision, err := vm.step(op); err != nil { + if vm.pc() >= module.Len() { + return nil + } + + op, advance := module.GetOp(vm.pc()) + vm.advancePC(advance) + + if decision, err := vm.step(module, op); err != nil { return Error{ - Pc: vm.pc, - Line: vm.code.Debug().PCToLine(vm.pc), + Pos: vm.pos, + Line: module.Debug().PCToLine(vm.pc()), Err: err, } } else if decision == stepDecisionHalt { @@ -71,7 +123,7 @@ const ( stepDecisionHalt ) -func (vm *VM) step(op code.Op) (stepDecision, error) { +func (vm *VM) step(module *code.Code, op code.Op) (stepDecision, error) { var err error switch op { case code.OpNop: @@ -80,18 +132,18 @@ func (vm *VM) step(op code.Op) (stepDecision, error) { return stepDecisionHalt, nil case code.OpPushInt: - x, advance := vm.code.GetInt(vm.pc) - vm.pc += advance + x, advance := module.GetInt(vm.pc()) + vm.advancePC(advance) vm.execPushInt(x) case code.OpPushFloat: - x, advance := vm.code.GetFloat(vm.pc) - vm.pc += advance + x, advance := module.GetFloat(vm.pc()) + vm.advancePC(advance) vm.execPushFloat(x) case code.OpPushString: - str, advance := vm.code.GetString(vm.pc) - vm.pc += advance + str, advance := module.GetString(vm.pc()) + vm.advancePC(advance) err = vm.execPushString(str) case code.OpPushNull: @@ -103,60 +155,73 @@ func (vm *VM) step(op code.Op) (stepDecision, error) { case code.OpPushArray: vm.execPushArray() case code.OpPushFunction: - x, advance := vm.code.GetUint(vm.pc) - vm.pc += advance + x, advance := module.GetUint(vm.pc()) + vm.advancePC(advance) vm.execPushFunction(int(x)) case code.OpPushObject: err = vm.execPushObject() case code.OpPushType: - name, advance := vm.code.GetString(vm.pc) - vm.pc += advance + name, advance := module.GetString(vm.pc()) + vm.advancePC(advance) err = vm.execPushType(name) case code.OpDrop: - dropAmount, advance := vm.code.GetUint(vm.pc) - vm.pc += advance + dropAmount, advance := module.GetUint(vm.pc()) + vm.advancePC(advance) err = vm.execDrop(uint(dropAmount)) + case code.OpAddGlobal: + globalName, advance := module.GetString(vm.pc()) + vm.advancePC(advance) + + err = vm.execAddGlobal(globalName) case code.OpGetGlobal: - panic("not implemented") + globalName, advance := module.GetString(vm.pc()) + vm.advancePC(advance) + + err = vm.execGetGlobal(globalName) + case code.OpSetGlobal: + globalName, advance := module.GetString(vm.pc()) + vm.advancePC(advance) + + err = vm.execSetGlobal(globalName) case code.OpGetLocal: - offset, advance := vm.code.GetInt(vm.pc) - vm.pc += advance + offset, advance := module.GetInt(vm.pc()) + vm.advancePC(advance) err = vm.execGetLocal(int(offset)) case code.OpSetLocal: - offset, advance := vm.code.GetInt(vm.pc) - vm.pc += advance + offset, advance := module.GetInt(vm.pc()) + vm.advancePC(advance) err = vm.execSetLocal(int(offset)) case code.OpGetMember: - name, advance := vm.code.GetString(vm.pc) - vm.pc += advance + name, advance := module.GetString(vm.pc()) + vm.advancePC(advance) err = vm.execGetMember(name) case code.OpSetMember: - name, advance := vm.code.GetString(vm.pc) - vm.pc += advance + name, advance := module.GetString(vm.pc()) + vm.advancePC(advance) err = vm.execSetMember(name) case code.OpGetEnv: - envIndex, advance := vm.code.GetUint(vm.pc) - vm.pc += advance + envIndex, advance := module.GetUint(vm.pc()) + vm.advancePC(advance) err = vm.execGetEnv(int(envIndex)) case code.OpSetEnv: - envIndex, advance := vm.code.GetUint(vm.pc) - vm.pc += advance + envIndex, advance := module.GetUint(vm.pc()) + vm.advancePC(advance) err = vm.execSetEnv(int(envIndex)) case code.OpAddToEnv: - local, advance := vm.code.GetUint(vm.pc) - vm.pc += advance + local, advance := module.GetUint(vm.pc()) + vm.advancePC(advance) err = vm.execAddToEnv(int(local)) @@ -164,8 +229,8 @@ func (vm *VM) step(op code.Op) (stepDecision, error) { err = vm.execAnchorType() case code.OpSetArgCount: - argCount, advance := vm.code.GetUint(vm.pc) - vm.pc += advance + argCount, advance := module.GetUint(vm.pc()) + vm.advancePC(advance) err = vm.execSetArgCount(uint(argCount)) @@ -193,21 +258,21 @@ func (vm *VM) step(op code.Op) (stepDecision, error) { case code.OpIndex: err = vm.execIndex() case code.OpCall: - argCount, advance := vm.code.GetUint(vm.pc) - vm.pc += advance + argCount, advance := module.GetUint(vm.pc()) + vm.advancePC(advance) err = vm.execCall(uint(argCount)) case code.OpJmp: - pc, _ := vm.code.GetUint(vm.pc) - vm.pc = int(pc) + pc, _ := module.GetUint(vm.pc()) + vm.setPC(int(pc)) case code.OpJt: - pc, advance := vm.code.GetUint(vm.pc) - vm.pc += advance + pc, advance := module.GetUint(vm.pc()) + vm.advancePC(advance) err = vm.execJumpIf(int(pc), true) case code.OpJf: - pc, advance := vm.code.GetUint(vm.pc) - vm.pc += advance + pc, advance := module.GetUint(vm.pc()) + vm.advancePC(advance) err = vm.execJumpIf(int(pc), false) case code.OpRet: err = vm.execRet() diff --git a/pkg/lang/vm/vm_test.go b/pkg/lang/vm/vm_test.go index ae0b8e6..b25a446 100644 --- a/pkg/lang/vm/vm_test.go +++ b/pkg/lang/vm/vm_test.go @@ -440,9 +440,50 @@ func TestPrimes(t *testing.T) { test(t, src, "[2, 3, 5, 7]") } +func TestModules(t *testing.T) { + mainSrc := ` + get_global "add_one" + get_global "x" + call 1 + halt + ` + + librarySrc1 := ` + push_int 41 + add_global "x" + halt + ` + + librarySrc2 := ` + push_function @add_one + set_arg_count 1 + add_global "add_one" + halt + + @add_one: + get_local 0 + push_int 1 + add + ret + ` + + main := compile(t, mainSrc) + library1 := compile(t, librarySrc1) + library2 := compile(t, librarySrc2) + + vm := vm.New(&main, []*code.Code{&library1, &library2}) + err := vm.Run() + require.NoError(t, err) + + res, err := vm.GetResult() + require.NoError(t, err) + + require.Equal(t, "42", res) +} + func test(t *testing.T, src string, expected string) { bc := compile(t, src) - vm := vm.New(&bc) + vm := vm.New(&bc, nil) err := vm.Run() require.NoError(t, err) |
