diff options
| author | Mel <einebeere@gmail.com> | 2022-05-28 01:22:17 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-05-28 01:22:17 +0000 |
| commit | 0a7700112f82e634a957685bee0cbaa3458f4945 (patch) | |
| tree | 847c397970d7d852bc988a7a01f4625eae443edb /pkg | |
| parent | 83d1dc87f3336d70ccda476627c70c282b7b6e11 (diff) | |
| download | jinx-0a7700112f82e634a957685bee0cbaa3458f4945.tar.zst jinx-0a7700112f82e634a957685bee0cbaa3458f4945.zip | |
Harden VM Mem
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/lang/vm/errors.go | 29 | ||||
| -rw-r--r-- | pkg/lang/vm/exec.go | 103 | ||||
| -rw-r--r-- | pkg/lang/vm/mem/cell.go | 25 | ||||
| -rw-r--r-- | pkg/lang/vm/mem/errors.go | 17 | ||||
| -rw-r--r-- | pkg/lang/vm/mem/mem.go | 74 | ||||
| -rw-r--r-- | pkg/lang/vm/utils.go | 69 | ||||
| -rw-r--r-- | pkg/lang/vm/value/cells.go | 46 | ||||
| -rw-r--r-- | pkg/lang/vm/value/data.go | 80 | ||||
| -rw-r--r-- | pkg/lang/vm/value/value.go | 26 | ||||
| -rw-r--r-- | pkg/lang/vm/vm.go | 31 |
10 files changed, 389 insertions, 111 deletions
diff --git a/pkg/lang/vm/errors.go b/pkg/lang/vm/errors.go index 264dd3a..2f5b56a 100644 --- a/pkg/lang/vm/errors.go +++ b/pkg/lang/vm/errors.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "jinx/pkg/lang/vm/code" + "jinx/pkg/lang/vm/mem" "jinx/pkg/lang/vm/text" "jinx/pkg/lang/vm/value" ) @@ -31,6 +32,8 @@ var ( ErrReachedRootCallFrame = errors.New("reached root call frame") ErrCallBaseCantBeNegative = errors.New("call base cannot be negative") + + ErrEnvNotSet = errors.New("env not set") ) type ErrLocalIndexOutOfBounds struct { @@ -59,6 +62,32 @@ func (e ErrInvalidOp) Error() string { return fmt.Sprintf("invalid opcode: %d", e.Op) } +type ErrUnexpectedMemCell struct { + Ptr mem.Ptr + Expected mem.CellKind + Got mem.CellKind +} + +func (e ErrUnexpectedMemCell) Error() string { + return fmt.Sprintf("unexpected memory cell at %s: expected %v, got %v", e.Ptr.String(), e.Expected, e.Got) +} + +type ErrMemNilCell struct { + Ptr mem.Ptr +} + +func (e ErrMemNilCell) Error() string { + return fmt.Sprintf("found no value at %s", e.Ptr.String()) +} + +type ErrCorruptedMemCell struct { + Ptr mem.Ptr +} + +func (e ErrCorruptedMemCell) Error() string { + return fmt.Sprintf("corrupted memory cell at %s", e.Ptr.String()) +} + // Non-fatal errors, which will later be implemented as catchable exceptions type ErrInvalidOperandType struct { diff --git a/pkg/lang/vm/exec.go b/pkg/lang/vm/exec.go index ee977b7..c9b8615 100644 --- a/pkg/lang/vm/exec.go +++ b/pkg/lang/vm/exec.go @@ -14,8 +14,14 @@ func (vm *VM) execPushFloat(x float64) { vm.stack.Push(value.NewFloat(x)) } -func (vm *VM) execPushString(str string) { - vm.stack.Push(value.NewString(&vm.memory, str)) +func (vm *VM) execPushString(str string) error { + val, err := value.NewString(&vm.memory, str) + if err != nil { + return err + } + + vm.stack.Push(val) + return nil } func (vm *VM) execPushBool(b bool) { @@ -26,8 +32,14 @@ func (vm *VM) execPushNull() { vm.stack.Push(value.NewNull()) } -func (vm *VM) execPushArray() { - vm.stack.Push(value.NewArray(&vm.memory, []value.Value{})) +func (vm *VM) execPushArray() error { + val, err := value.NewArray(&vm.memory, []value.Value{}) + if err != nil { + return err + } + + vm.stack.Push(val) + return nil } func (vm *VM) execPushFunction(pc int) { @@ -47,15 +59,23 @@ func (vm *VM) execGetLocal(offset int) error { } func (vm *VM) execGetEnv(envIndex int) error { - envPtr := vm.stack.CurrentCallEnv() - env := vm.memory.Get(envPtr).(value.Env) + envCell, err := vm.getMemCell(vm.stack.CurrentCallEnv(), mem.CellKindEnv, false) + if err != nil { + return err + } + env := envCell.(value.EnvCell).Get() // First check outlet. outletPtr := env.GetOutlet(envIndex) - outlet := vm.memory.Get(outletPtr) - if outlet != nil { + outletCell, err := vm.getMemCell(outletPtr, mem.CellKindOutlet, true) + if err != nil { + return err + } + + if outletCell != nil { // Outlet is not null, so value escaped. - val := outlet.(value.Value).Clone(&vm.memory) + outlet := outletCell.(value.OutletCell).Get() + val := outlet.Clone(&vm.memory) vm.stack.Push(val) return nil } @@ -79,14 +99,23 @@ func (vm *VM) execSetEnv(envIndex int) error { return err } - envPtr := vm.stack.CurrentCallEnv() - env := vm.memory.Get(envPtr).(value.Env) + envCell, err := vm.getMemCell(vm.stack.CurrentCallEnv(), mem.CellKindEnv, false) + if err != nil { + return err + } + env := envCell.(value.EnvCell).Get() outletPtr := env.GetOutlet(envIndex) - outlet := vm.memory.Get(outletPtr) - if outlet != nil { - outlet.(value.Value).Drop(&vm.memory) - vm.memory.Set(outletPtr, new) + outletCell, err := vm.getMemCell(outletPtr, mem.CellKindOutlet, true) + if err != nil { + return err + } + + if outletCell != nil { + // Outlet is not null, so value escaped. + outlet := outletCell.(value.OutletCell).Get() + outlet.Drop(&vm.memory) + vm.memory.Set(outletPtr, value.OutletCell(new)) return nil } @@ -119,8 +148,12 @@ func (vm *VM) execAddToEnv(localIndex int) error { var envPtr mem.Ptr if fn.Env().IsNull() { // Allocate new Env. - envPtr = vm.memory.Allocate(mem.CellKindEnv) - vm.memory.Set(envPtr, value.NewEnv()) + envPtr, err = vm.memory.Allocate(mem.CellKindEnv) + if err != nil { + return err + } + + vm.memory.Set(envPtr, value.EnvCell(value.NewEnv())) fn = fn.WithEnv(envPtr) } else { envPtr = fn.Env() @@ -133,16 +166,24 @@ func (vm *VM) execAddToEnv(localIndex int) error { } if local.Outlet().IsNull() { - outlet := vm.memory.Allocate(mem.CellKindOutlet) - local = local.WithOutlet(outlet) + outletPtr, err := vm.memory.Allocate(mem.CellKindOutlet) + if err != nil { + return err + } + local = local.WithOutlet(outletPtr) } // Add local to env. stackIndex := vm.stack.LocalToStackIndex(localIndex) - env := vm.memory.Get(envPtr).(value.Env) + envCell, err := vm.getMemCell(envPtr, mem.CellKindEnv, false) + if err != nil { + return err + } + + env := envCell.(value.EnvCell).Get() env.Add(stackIndex, local.Outlet()) - vm.memory.Set(envPtr, env) + vm.memory.Set(envPtr, value.EnvCell(env)) // Push function back onto stack. f = f.WithData(fn) @@ -295,15 +336,22 @@ func (vm *VM) execIndex() error { switch i.Type().Kind { case value.IntType: idx := i.Data().(value.IntData).Get() - len := int64(arr.Len(&vm.memory)) - if idx < 0 || idx >= len { + len, err := arr.Len(&vm.memory) + if err != nil { + return err + } + if idx < 0 || idx >= int64(len) { return ErrArrayIndexOutOfBounds{ Index: int(idx), Len: int(len), } } - val := arr.At(&vm.memory, int(idx)).Clone(&vm.memory) + val, err := arr.At(&vm.memory, int(idx)) + if err != nil { + return err + } + val = val.Clone(&vm.memory) vm.stack.Push(val) default: return ErrInvalidOperandTypes{ @@ -458,8 +506,11 @@ func (vm *VM) execTempArrLen() error { switch a.Type().Kind { case value.ArrayType: arr := a.Data().(value.ArrayData) - len := int64(arr.Len(&vm.memory)) - res := value.NewInt(len) + len, err := arr.Len(&vm.memory) + if err != nil { + return err + } + res := value.NewInt(int64(len)) vm.stack.Push(res) default: return ErrInvalidOperandTypes{ diff --git a/pkg/lang/vm/mem/cell.go b/pkg/lang/vm/mem/cell.go index 3f052d5..fcf9a50 100644 --- a/pkg/lang/vm/mem/cell.go +++ b/pkg/lang/vm/mem/cell.go @@ -13,8 +13,31 @@ const ( CellKindForbidden ) +func (c CellKind) String() string { + switch c { + case CellKindEmpty: + return "empty" + case CellKindString: + return "string" + case CellKindArray: + return "array" + case CellKindEnv: + return "env" + case CellKindOutlet: + return "outlet" + case CellKindForbidden: + return "forbidden" + default: + return "unknown" + } +} + type cell struct { kind CellKind refs int - data any + data CellData +} + +type CellData interface { + DropCell(*Mem) } diff --git a/pkg/lang/vm/mem/errors.go b/pkg/lang/vm/mem/errors.go new file mode 100644 index 0000000..e7a3f8f --- /dev/null +++ b/pkg/lang/vm/mem/errors.go @@ -0,0 +1,17 @@ +package mem + +import "errors" + +var ( + ErrMemOverflow = errors.New("memory overflow, cannot allocate more than 10000 memory cells") + + ErrFatalNonFreeCell = errors.New("non-free cell marked as free") +) + +type ErrInvalidMemAccess struct { + Ptr Ptr +} + +func (e ErrInvalidMemAccess) Error() string { + return "invalid memory access at " + e.Ptr.String() +} diff --git a/pkg/lang/vm/mem/mem.go b/pkg/lang/vm/mem/mem.go index 4cb6fc4..bdcf01f 100644 --- a/pkg/lang/vm/mem/mem.go +++ b/pkg/lang/vm/mem/mem.go @@ -2,7 +2,7 @@ package mem type Mem struct { cells []cell - free []int + free []Ptr } func New() Mem { @@ -13,69 +13,101 @@ func New() Mem { return Mem{ cells: cells, - free: make([]int, 0), + free: make([]Ptr, 0), } } -func (m *Mem) Allocate(kind CellKind) Ptr { +func (m *Mem) Allocate(kind CellKind) (Ptr, error) { if len(m.free) > 0 { idx := m.free[len(m.free)-1] m.free = m.free[:len(m.free)-1] if m.cells[idx].kind != CellKindEmpty { - panic("invalid free cell") + return NullPtr, ErrFatalNonFreeCell } m.cells[idx].kind = kind - return Ptr(idx) + return Ptr(idx), nil } else { + if len(m.cells) > 10000 { + return NullPtr, ErrMemOverflow + } + idx := len(m.cells) m.cells = append(m.cells, cell{kind: kind, refs: 1}) - return Ptr(idx) + return Ptr(idx), nil } } -func (m *Mem) Set(ptr Ptr, v any) { - if ptr >= Ptr(len(m.cells)) { - panic("out of bounds") +func (m *Mem) Set(ptr Ptr, v CellData) error { + if err := m.validPtr(ptr); err != nil { + return err } m.cells[ptr].data = v + return nil } -func (m *Mem) Get(ptr Ptr) any { - if ptr >= Ptr(len(m.cells)) { - panic("out of bounds") +func (m *Mem) Get(ptr Ptr) (CellData, error) { + if err := m.validPtr(ptr); err != nil { + return nil, err } - return m.cells[ptr].data + return m.cells[ptr].data, nil } func (m *Mem) Is(ptr Ptr, kind CellKind) bool { if ptr >= Ptr(len(m.cells)) { - panic("out of bounds") + return false } return m.cells[ptr].kind == kind } -func (m *Mem) Retain(ptr Ptr) { +func (m *Mem) Kind(ptr Ptr) CellKind { if ptr >= Ptr(len(m.cells)) { - panic("out of bounds") + return CellKindForbidden + } + + return m.cells[ptr].kind +} + +func (m *Mem) Retain(ptr Ptr) error { + if err := m.validPtr(ptr); err != nil { + return err } m.cells[ptr].refs++ + return nil } -func (m *Mem) Release(ptr Ptr) { - if ptr >= Ptr(len(m.cells)) { - panic("out of bounds") +func (m *Mem) Release(ptr Ptr) error { + if err := m.validPtr(ptr); err != nil { + return err } m.cells[ptr].refs-- if m.cells[ptr].refs == 0 { - m.cells[ptr].kind = CellKindEmpty - m.free = append(m.free, int(ptr)) + c := m.cells[ptr].data + c.DropCell(m) + + m.cells[ptr] = cell{} + m.free = append(m.free, ptr) } + + return nil +} + +func (m *Mem) validPtr(ptr Ptr) error { + if ptr >= Ptr(len(m.cells)) { + return ErrInvalidMemAccess{ptr} + } + + kind := m.cells[ptr].kind + if kind == CellKindForbidden || kind == CellKindEmpty { + return ErrInvalidMemAccess{ptr} + } + + return nil } diff --git a/pkg/lang/vm/utils.go b/pkg/lang/vm/utils.go new file mode 100644 index 0000000..95c0298 --- /dev/null +++ b/pkg/lang/vm/utils.go @@ -0,0 +1,69 @@ +package vm + +import ( + "jinx/pkg/lang/vm/mem" + "jinx/pkg/lang/vm/value" +) + +func (vm *VM) popAndDrop() (value.Value, error) { + v, err := vm.stack.Pop() + if err != nil { + return value.Value{}, err + } + v.Drop(&vm.memory) + return v, nil +} + +func (vm *VM) popCallAndDrop() (int, error) { + envPtr := vm.stack.CurrentCallEnv() + vm.memory.Release(envPtr) + + for !vm.stack.ReachedBaseOfCall() { + _, err := vm.popAndDrop() + if err != nil { + return 0, err + } + } + + return vm.stack.PopCall() +} + +func (vm *VM) getMemCell(ptr mem.Ptr, kind mem.CellKind, allowNil bool) (mem.CellData, error) { + if ptr.IsNull() { + return nil, ErrEnvNotSet + } + + if !vm.memory.Is(ptr, kind) { + return nil, ErrUnexpectedMemCell{Ptr: ptr, Expected: mem.CellKindEnv, Got: vm.memory.Kind(ptr)} + } + + cell, err := vm.memory.Get(ptr) + if err != nil { + return nil, err + } + + if cell == nil { + if allowNil { + return nil, nil + } + return nil, ErrMemNilCell{Ptr: ptr} + } + + ok := false + switch kind { + case mem.CellKindString: + _, ok = cell.(value.StringCell) + case mem.CellKindArray: + _, ok = cell.(value.ArrayCell) + case mem.CellKindEnv: + _, ok = cell.(value.EnvCell) + case mem.CellKindOutlet: + _, ok = cell.(value.OutletCell) + } + + if !ok { + return nil, ErrCorruptedMemCell{Ptr: ptr} + } + + return cell, nil +} diff --git a/pkg/lang/vm/value/cells.go b/pkg/lang/vm/value/cells.go new file mode 100644 index 0000000..17a0916 --- /dev/null +++ b/pkg/lang/vm/value/cells.go @@ -0,0 +1,46 @@ +package value + +import "jinx/pkg/lang/vm/mem" + +type ArrayCell []Value + +func (a ArrayCell) DropCell(m *mem.Mem) { + for _, v := range a { + v.Drop(m) + } +} + +func (a ArrayCell) Get() []Value { + return a +} + +type StringCell string + +func (s StringCell) DropCell(m *mem.Mem) { +} + +func (s StringCell) Get() string { + return string(s) +} + +type OutletCell Value + +func (o OutletCell) DropCell(m *mem.Mem) { + Value(o).Drop(m) +} + +func (o OutletCell) Get() Value { + return Value(o) +} + +type EnvCell Env + +func (e EnvCell) DropCell(m *mem.Mem) { + for _, v := range e.references { + m.Release(v.outlet) + } +} + +func (e EnvCell) Get() Env { + return Env(e) +} diff --git a/pkg/lang/vm/value/data.go b/pkg/lang/vm/value/data.go index a49753e..2e3b3e6 100644 --- a/pkg/lang/vm/value/data.go +++ b/pkg/lang/vm/value/data.go @@ -8,7 +8,7 @@ import ( ) type Data interface { - String(*mem.Mem) string + String(*mem.Mem) (string, error) } type IntData int64 @@ -17,8 +17,8 @@ func (i IntData) Get() int64 { return int64(i) } -func (i IntData) String(_ *mem.Mem) string { - return strconv.FormatInt(int64(i), 10) +func (i IntData) String(_ *mem.Mem) (string, error) { + return strconv.FormatInt(int64(i), 10), nil } type FloatData float64 @@ -27,17 +27,20 @@ func (f FloatData) Get() float64 { return float64(f) } -func (f FloatData) String(_ *mem.Mem) string { - return strconv.FormatFloat(float64(f), 'f', -1, 64) +func (f FloatData) String(_ *mem.Mem) (string, error) { + return strconv.FormatFloat(float64(f), 'f', -1, 64), nil } type StringData struct { data mem.Ptr } -func (s StringData) String(m *mem.Mem) string { - data := m.Get(s.data) - return "\"" + data.(string) + "\"" +func (s StringData) String(m *mem.Mem) (string, error) { + if data, err := m.Get(s.data); err == nil { + return "\"" + data.(StringCell).Get() + "\"", nil + } else { + return "", err + } } type BoolData bool @@ -46,16 +49,21 @@ func (b BoolData) Get() bool { return bool(b) } -func (b BoolData) String(_ *mem.Mem) string { - return strconv.FormatBool(bool(b)) +func (b BoolData) String(_ *mem.Mem) (string, error) { + return strconv.FormatBool(bool(b)), nil } type ArrayData struct { data mem.Ptr } -func (a ArrayData) String(m *mem.Mem) string { - arr := m.Get(a.data).([]Value) +func (a ArrayData) String(m *mem.Mem) (string, error) { + val, err := m.Get(a.data) + if err != nil { + return "", err + } + + arr := val.(ArrayCell).Get() builder := strings.Builder{} builder.WriteString("[") @@ -63,29 +71,45 @@ func (a ArrayData) String(m *mem.Mem) string { if i > 0 { builder.WriteString(", ") } - builder.WriteString(v.Data().String(m)) + if s, err := v.Data().String(m); err == nil { + builder.WriteString(s) + } else { + return "", err + } } builder.WriteString("]") - return builder.String() + return builder.String(), nil } -func (a ArrayData) Len(m *mem.Mem) int { - data := m.Get(a.data) - arr := data.([]Value) - return len(arr) +func (a ArrayData) Len(m *mem.Mem) (int, error) { + data, err := m.Get(a.data) + if err != nil { + return 0, err + } + arr := data.(ArrayCell).Get() + return len(arr), nil } -func (a ArrayData) At(m *mem.Mem, i int) Value { - data := m.Get(a.data) - arr := data.([]Value) - return arr[i] +func (a ArrayData) At(m *mem.Mem, i int) (Value, error) { + data, err := m.Get(a.data) + if err != nil { + return Value{}, err + } + arr := data.(ArrayCell).Get() + return arr[i], nil } -func (a ArrayData) Push(m *mem.Mem, v Value) { - data := m.Get(a.data) - arr := data.([]Value) +func (a ArrayData) Push(m *mem.Mem, v Value) error { + data, err := m.Get(a.data) + if err != nil { + return err + } + + arr := data.(ArrayCell).Get() arr = append(arr, v) - m.Set(a.data, arr) + m.Set(a.data, ArrayCell(arr)) + + return nil } type NullData struct{} @@ -107,8 +131,8 @@ func (f FunctionData) WithEnv(env mem.Ptr) FunctionData { return FunctionData{pc: f.pc, env: env} } -func (f FunctionData) String(_ *mem.Mem) string { - return fmt.Sprintf("<fn %d>", f.pc) +func (f FunctionData) String(_ *mem.Mem) (string, error) { + return fmt.Sprintf("<fn %d>", f.pc), nil } type ObjectData struct{} // TODO diff --git a/pkg/lang/vm/value/value.go b/pkg/lang/vm/value/value.go index d19bbd6..f98740e 100644 --- a/pkg/lang/vm/value/value.go +++ b/pkg/lang/vm/value/value.go @@ -20,13 +20,17 @@ func NewFloat(x float64) Value { return Value{t: t, d: FloatData(x)} } -func NewString(m *mem.Mem, str string) Value { +func NewString(m *mem.Mem, str string) (Value, error) { t := Type{Kind: StringType} - ptr := m.Allocate(mem.CellKindString) - m.Set(ptr, str) + ptr, err := m.Allocate(mem.CellKindString) + if err != nil { + return Value{}, err + } + + m.Set(ptr, StringCell(str)) - return Value{t: t, d: StringData{data: ptr}} + return Value{t: t, d: StringData{data: ptr}}, nil } func NewBool(b bool) Value { @@ -34,13 +38,17 @@ func NewBool(b bool) Value { return Value{t: t, d: BoolData(b)} } -func NewArray(m *mem.Mem, arr []Value) Value { +func NewArray(m *mem.Mem, arr []Value) (Value, error) { t := Type{Kind: ArrayType} - ptr := m.Allocate(mem.CellKindArray) - m.Set(ptr, arr) + ptr, err := m.Allocate(mem.CellKindArray) + if err != nil { + return Value{}, err + } + + m.Set(ptr, ArrayCell(arr)) - return Value{t: t, d: ArrayData{data: ptr}} + return Value{t: t, d: ArrayData{data: ptr}}, nil } func NewNull() Value { @@ -99,7 +107,7 @@ func (v Value) Clone(m *mem.Mem) Value { func (v Value) Drop(m *mem.Mem) { // If value has an outlet, don't drop it and instead move it to the outlet. if !v.outlet.IsNull() { - m.Set(v.outlet, v) + m.Set(v.outlet, OutletCell(v)) return } diff --git a/pkg/lang/vm/vm.go b/pkg/lang/vm/vm.go index c933af4..9c17a3f 100644 --- a/pkg/lang/vm/vm.go +++ b/pkg/lang/vm/vm.go @@ -3,7 +3,6 @@ package vm import ( "jinx/pkg/lang/vm/code" "jinx/pkg/lang/vm/mem" - "jinx/pkg/lang/vm/value" ) type VM struct { @@ -28,8 +27,11 @@ func (vm *VM) GetResult() (string, error) { return "", err } - str := res.Data().String(&vm.memory) - return str, nil + if str, err := res.Data().String(&vm.memory); err == nil { + return str, nil + } else { + return "", err + } } func (vm *VM) Run() error { @@ -169,26 +171,3 @@ func (vm *VM) step(op code.Op) (stepDecision, error) { return stepDecisionContinue, err } - -func (vm *VM) popAndDrop() (value.Value, error) { - v, err := vm.stack.Pop() - if err != nil { - return value.Value{}, err - } - v.Drop(&vm.memory) - return v, nil -} - -func (vm *VM) popCallAndDrop() (int, error) { - envPtr := vm.stack.CurrentCallEnv() - vm.memory.Release(envPtr) - - for !vm.stack.ReachedBaseOfCall() { - _, err := vm.popAndDrop() - if err != nil { - return 0, err - } - } - - return vm.stack.PopCall() -} |
