diff options
| -rw-r--r-- | pkg/lang/vm/text/decompiler.go | 137 | ||||
| -rw-r--r-- | pkg/lang/vm/text/decompiler_test.go | 75 |
2 files changed, 212 insertions, 0 deletions
diff --git a/pkg/lang/vm/text/decompiler.go b/pkg/lang/vm/text/decompiler.go new file mode 100644 index 0000000..c8922ec --- /dev/null +++ b/pkg/lang/vm/text/decompiler.go @@ -0,0 +1,137 @@ +package text + +import ( + "encoding/binary" + "fmt" + "jinx/pkg/lang/vm/code" + "jinx/pkg/libs/rangemap" + "math" + "strconv" + "strings" +) + +type Decompiler struct { + c code.Code + + pcToLine rangemap.RangeMap[int] +} + +func NewDecompiler(c code.Code) *Decompiler { + return &Decompiler{ + c: c, + pcToLine: rangemap.New[int](), + } +} + +func (d *Decompiler) Decompile() string { + lines := make([]string, 0) + bc := d.c.Code() + + for len(bc) != 0 { + line, rest := d.decompileInstruction(bc) + bc = rest + + d.pcToLine.AppendToLast(d.c.Len()-len(bc), len(lines)) + + lines = append(lines, line) + } + + return strings.Join(lines, "\n") +} + +func (d *Decompiler) decompileInstruction(bc code.Raw) (string, code.Raw) { + op := code.Op(bc[0]) + opString := OpToString(op) + if opString == "unknown" { + return fmt.Sprintf("unknown(%x)", bc[0]), bc[1:] + } + + switch op { + // Operations that take no arguments. + case code.OpNop, + code.OpHalt, + code.OpPushTrue, + code.OpPushFalse, + code.OpPushNull, + code.OpPushArray, + code.OpPushObject, + code.OpDrop, + code.OpAnchorType, + code.OpAdd, + code.OpSub, + code.OpMod, + code.OpIndex, + code.OpLte, + code.OpRet, + code.OpTempArrLen, + code.OpTempArrPush: + return opString, bc[1:] + + // Operations that take an int. + case code.OpPushInt, + code.OpGetLocal, + code.OpSetLocal, + code.OpGetEnv, + code.OpSetEnv, + code.OpAddToEnv, + code.OpCall: + i, rest := d.decompileInt(bc[1:]) + return fmt.Sprintf("%s %s", opString, i), rest + + // Operations that take a float. + case code.OpPushFloat: + f, rest := d.decompileFloat(bc[1:]) + return fmt.Sprintf("%s %s", opString, f), rest + + // Operations that take a string. + case code.OpPushString, + code.OpPushType, + code.OpGetGlobal, + code.OpSetMember, + code.OpGetMember: + s, rest := d.decompileString(bc[1:]) + return fmt.Sprintf("%s %s", opString, s), rest + + // Operations that take a pc, belonging to a function. + case code.OpPushFunction: + // TODO: Add function labels to output and give them names. + fallthrough + + // Operations that take a pc, belonging to a non-call jump. (when branching) + case code.OpJmp, + code.OpJt, + code.OpJf: + // TODO: Add jump labels to output and give them names. + i, rest := d.decompileInt(bc[1:]) + return fmt.Sprintf("%s @%s", opString, i), rest + } + + panic("decompiler can't decompile op: " + opString) +} + +func (d *Decompiler) decompileInt(bc code.Raw) (string, code.Raw) { + i := binary.LittleEndian.Uint64(bc[:8]) + return strconv.FormatInt(int64(i), 10), bc[8:] +} + +func (d *Decompiler) decompileFloat(bc code.Raw) (string, code.Raw) { + i := binary.LittleEndian.Uint64(bc[:8]) + f := math.Float64frombits(i) + return strconv.FormatFloat(f, 'f', -1, 64), bc[8:] +} + +func (d *Decompiler) decompileString(bc code.Raw) (string, code.Raw) { + var buf strings.Builder + buf.WriteString("\"") + end := 0 + for i, b := range bc { + if b == 0 { + end = i + break + } + buf.WriteByte(b) + } + buf.WriteString("\"") + + return buf.String(), bc[end+1:] +} diff --git a/pkg/lang/vm/text/decompiler_test.go b/pkg/lang/vm/text/decompiler_test.go new file mode 100644 index 0000000..01e49b0 --- /dev/null +++ b/pkg/lang/vm/text/decompiler_test.go @@ -0,0 +1,75 @@ +package text_test + +import ( + "jinx/pkg/lang/vm/text" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDecompileSimple(t *testing.T) { + src := ` + push_int 1 + push_int 2 + + add + ` + + expected := ` + push_int 1 + push_int 2 + add + ` + + test(t, src, expected) +} + +func TestDecompileValues(t *testing.T) { + src := ` + push_int 1 + push_string "foo" + push_float 3.14 + push_function @foo + halt + @foo: + push_int 1 + ret + ` + + // No label names yet. + expected := ` + push_int 1 + push_string "foo" + push_float 3.14 + push_function @33 + halt + push_int 1 + ret + ` + + test(t, src, expected) +} + +func test(t *testing.T, code string, expected string) { + expectedLines := strings.Split(expected, "\n") + trimmedExpectedLines := make([]string, 0, len(expectedLines)) + for _, line := range expectedLines { + trimmedLine := strings.TrimSpace(line) + if trimmedLine == "" { + continue + } + + trimmedExpectedLines = append(trimmedExpectedLines, trimmedLine) + } + + trimmedExpected := strings.Join(trimmedExpectedLines, "\n") + + comp := text.NewCompiler(strings.NewReader(code)) + resCompiled, err := comp.Compile() + require.NoError(t, err) + + decomp := text.NewDecompiler(resCompiled) + resDecompiled := decomp.Decompile() + require.Equal(t, trimmedExpected, resDecompiled) +} |
