From a8e90423a6d29d7241b7d1f5ab60d607edffef9f Mon Sep 17 00:00:00 2001 From: Mel Date: Sat, 28 May 2022 20:56:10 +0000 Subject: Harden Mem further and add tests --- pkg/lang/vm/mem/mem.go | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) (limited to 'pkg/lang/vm/mem/mem.go') diff --git a/pkg/lang/vm/mem/mem.go b/pkg/lang/vm/mem/mem.go index a645ee2..0f15743 100644 --- a/pkg/lang/vm/mem/mem.go +++ b/pkg/lang/vm/mem/mem.go @@ -10,6 +10,7 @@ type Mem interface { Retain(ptr Ptr) error Release(ptr Ptr) error + RefCount(ptr Ptr) int } type memImpl struct { @@ -30,15 +31,20 @@ func New() Mem { } func (m *memImpl) Allocate(kind CellKind) (Ptr, error) { + if kind == CellKindForbidden || kind == CellKindEmpty { + return NullPtr, ErrInvalidCellKind{kind} + } + if len(m.free) > 0 { idx := m.free[len(m.free)-1] m.free = m.free[:len(m.free)-1] if m.cells[idx].kind != CellKindEmpty { - return NullPtr, ErrFatalNonFreeCell + // This should never happen. + panic("cell marked as free was not empty") } - m.cells[idx].kind = kind + m.cells[idx] = cell{kind: kind, refs: 1} return Ptr(idx), nil } else { if len(m.cells) > 10000 { @@ -57,6 +63,10 @@ func (m *memImpl) Set(ptr Ptr, v CellData) error { return err } + if m.cells[ptr].kind != v.MatchingCellKind() { + return ErrDifferingCellKind{Ptr: ptr, Expected: m.cells[ptr].kind, Got: v.MatchingCellKind()} + } + m.cells[ptr].data = v return nil } @@ -71,7 +81,7 @@ func (m *memImpl) Get(ptr Ptr) (CellData, error) { func (m *memImpl) Is(ptr Ptr, kind CellKind) bool { if ptr >= Ptr(len(m.cells)) { - return false + return kind == CellKindEmpty } return m.cells[ptr].kind == kind @@ -79,7 +89,7 @@ func (m *memImpl) Is(ptr Ptr, kind CellKind) bool { func (m *memImpl) Kind(ptr Ptr) CellKind { if ptr >= Ptr(len(m.cells)) { - return CellKindForbidden + return CellKindEmpty } return m.cells[ptr].kind @@ -111,6 +121,14 @@ func (m *memImpl) Release(ptr Ptr) error { return nil } +func (m *memImpl) RefCount(ptr Ptr) int { + if err := m.validPtr(ptr); err != nil { + return 0 + } + + return m.cells[ptr].refs +} + func (m *memImpl) validPtr(ptr Ptr) error { if ptr >= Ptr(len(m.cells)) { return ErrInvalidMemAccess{ptr} -- cgit 1.4.1