diff options
Diffstat (limited to 'pkg')
| -rw-r--r-- | pkg/lang/vm/code/code.go | 10 | ||||
| -rw-r--r-- | pkg/lang/vm/text/compiler.go | 202 | ||||
| -rw-r--r-- | pkg/lang/vm/text/compiler_test.go | 164 | ||||
| -rw-r--r-- | pkg/lang/vm/text/errors.go | 33 | ||||
| -rw-r--r-- | pkg/lang/vm/text/op.go | 56 |
5 files changed, 465 insertions, 0 deletions
diff --git a/pkg/lang/vm/code/code.go b/pkg/lang/vm/code/code.go index ab06d27..bd25fcd 100644 --- a/pkg/lang/vm/code/code.go +++ b/pkg/lang/vm/code/code.go @@ -11,10 +11,20 @@ type Code struct { code []byte } +func New(code []byte) Code { + return Code{ + code: code, + } +} + func (c *Code) Len() int { return len(c.code) } +func (c *Code) Code() []byte { + return c.code +} + func (c *Code) GetOp(at int) (Op, int) { return Op(c.code[at]), 1 } diff --git a/pkg/lang/vm/text/compiler.go b/pkg/lang/vm/text/compiler.go new file mode 100644 index 0000000..2732aea --- /dev/null +++ b/pkg/lang/vm/text/compiler.go @@ -0,0 +1,202 @@ +package text + +import ( + "encoding/binary" + "io" + "jinx/pkg/lang/vm/code" + "jinx/pkg/libs/source" + "math" + "strconv" + "strings" + "unicode" +) + +type Compiler struct { + src source.Walker + codePos int + + labelPositions map[string]int + labelReferences []labelReference +} + +func NewCompiler(src io.Reader) *Compiler { + return &Compiler{ + src: *source.NewWalker(src), + codePos: 0, + + labelPositions: map[string]int{}, + labelReferences: []labelReference{}, + } +} + +func (cpl *Compiler) Compile() (code.Code, error) { + res := []byte{} + + for { + _, eof, err := cpl.src.Peek() + if err != nil { + return code.Code{}, err + } + + if eof { + break + } + + line, err := cpl.compileLine() + if err != nil { + return code.Code{}, err + } + + cpl.codePos += len(line) + res = append(res, line...) + } + + if err := cpl.linkLabels(res); err != nil { + return code.Code{}, err + } + + return code.New(res), nil +} + +func (cpl *Compiler) compileLine() ([]byte, error) { + res := []byte{} + + start, value, err := cpl.splitLine() + if err != nil || start == "" { + return nil, err + } + + // Ignore lines starting with a comment. + if start[0] == '#' { + return nil, nil + } + + // Save the position of the label. + if start[0] == '@' { + label := strings.Trim(start, "@:") + if _, ok := cpl.labelPositions[label]; ok { + return nil, ErrDuplicateLabel{label} + } + cpl.labelPositions[label] = cpl.codePos + return nil, nil + } + + // Find the operator. + op, err := cpl.compileOp(start) + if err != nil { + return nil, err + } + + res = append(res, byte(op)) + + // Find the value, of which there is at most one. + if value != "" { + val, err := cpl.compileValue(value) + if err != nil { + return nil, err + } + + res = append(res, val...) + } + + return res, nil +} + +func (cpl *Compiler) compileOp(str string) (code.Op, error) { + op, err := StringToOp(str) + if err != nil { + return 0, err + } + + return op, nil +} + +func (cpl *Compiler) compileValue(str string) ([]byte, error) { + res := make([]byte, 8) + + // Save label reference. + if str[0] == '@' { + label := strings.Trim(str, "@:") + cpl.labelReferences = append(cpl.labelReferences, labelReference{ + label: label, + at: cpl.codePos + 1, // +1 to skip the opcode. + }) + return res, nil + } + + if unicode.IsDigit(rune(str[0])) || str[0] == '-' { + if strings.Contains(str, ".") { + val, err := strconv.ParseFloat(str, 64) + if err != nil { + return res, err + } + + binary.LittleEndian.PutUint64(res, math.Float64bits(val)) + } else { + val, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return res, err + } + + binary.LittleEndian.PutUint64(res, uint64(val)) + } + + return res, nil + } + + if str[0] == '"' { + str = strings.Trim(str, "\"") + res = []byte(str) + res = append(res, 0) + + return res, nil + } + + return res, ErrInvalidValue{str} +} + +func (cpl *Compiler) splitLine() (string, string, error) { + line := "" + + for { + c, eof, err := cpl.src.Next() + if err != nil { + return "", "", err + } + + if eof || c == '\n' { + break + } + + line += string(c) + } + + fields := strings.Fields(line) + if len(fields) == 0 { + return "", "", nil + } + + start := fields[0] + if len(fields) > 1 { + return start, strings.Join(fields[1:], " "), nil + } + + return start, "", nil +} + +type labelReference struct { + label string + at int +} + +func (cpl *Compiler) linkLabels(code []byte) error { + for _, ref := range cpl.labelReferences { + pos, ok := cpl.labelPositions[ref.label] + if !ok { + return ErrUnkonwnLabel{ref.label} + } + + binary.LittleEndian.PutUint64(code[ref.at:], uint64(pos)) + } + return nil +} diff --git a/pkg/lang/vm/text/compiler_test.go b/pkg/lang/vm/text/compiler_test.go new file mode 100644 index 0000000..cf2f6a9 --- /dev/null +++ b/pkg/lang/vm/text/compiler_test.go @@ -0,0 +1,164 @@ +package text_test + +import ( + "encoding/binary" + "jinx/pkg/lang/vm/code" + "jinx/pkg/lang/vm/text" + "math" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSimple(t *testing.T) { + src := ` + get_arg + get_arg + sub + ret + ` + + c := text.NewCompiler(strings.NewReader(src)) + res, err := c.Compile() + require.NoError(t, err) + + parts := [][]byte{ + opBin(code.OpGetArg), + opBin(code.OpGetArg), + opBin(code.OpSub), + opBin(code.OpRet), + } + + require.Equal(t, joinSlices(parts), res.Code()) +} + +func TestInt(t *testing.T) { + src := ` + push_int 1 + push_int 2 + add + ret + ` + + c := text.NewCompiler(strings.NewReader(src)) + res, err := c.Compile() + require.NoError(t, err) + + parts := [][]byte{ + opBin(code.OpPushInt), + uintBin(1), + opBin(code.OpPushInt), + uintBin(2), + opBin(code.OpAdd), + opBin(code.OpRet), + } + + require.Equal(t, joinSlices(parts), res.Code()) +} + +func TestFloat(t *testing.T) { + src := ` + push_float 3.1415 + push_float -2.71828 + ` + + c := text.NewCompiler(strings.NewReader(src)) + res, err := c.Compile() + require.NoError(t, err) + + parts := [][]byte{ + opBin(code.OpPushFloat), + floatBin(3.1415), + opBin(code.OpPushFloat), + floatBin(-2.71828), + } + + require.Equal(t, joinSlices(parts), res.Code()) +} + +func TestString(t *testing.T) { + src := ` + push_string "Hello, " + push_string "world!" + add + ` + + c := text.NewCompiler(strings.NewReader(src)) + res, err := c.Compile() + require.NoError(t, err) + + parts := [][]byte{ + opBin(code.OpPushString), + stringBin("Hello, "), + opBin(code.OpPushString), + stringBin("world!"), + opBin(code.OpAdd), + } + + require.Equal(t, joinSlices(parts), res.Code()) +} + +func TestLabels(t *testing.T) { + src := ` + @1: + nop + @2: + nop + @3: + nop + jmp @1 + jmp @2 + jmp @3 + ` + + c := text.NewCompiler(strings.NewReader(src)) + res, err := c.Compile() + require.NoError(t, err) + + parts := [][]byte{ + opBin(code.OpNop), + opBin(code.OpNop), + opBin(code.OpNop), + opBin(code.OpJmp), + uintBin(0), + opBin(code.OpJmp), + uintBin(1), + opBin(code.OpJmp), + uintBin(2), + } + + require.Equal(t, joinSlices(parts), res.Code()) +} + +func opBin(op code.Op) []byte { + return []byte{byte(op)} +} + +func uintBin(x uint64) []byte { + res := make([]byte, 8) + binary.LittleEndian.PutUint64(res, x) + return res +} + +func floatBin(x float64) []byte { + res := make([]byte, 8) + binary.LittleEndian.PutUint64(res, math.Float64bits(x)) + return res +} + +func stringBin(x string) []byte { + res := []byte(x) + res = append(res, 0) + return res +} + +func joinSlices[T any](slices [][]T) []T { + res := []T{} + + for _, slice := range slices { + res = append(res, slice...) + } + + return res +} diff --git a/pkg/lang/vm/text/errors.go b/pkg/lang/vm/text/errors.go new file mode 100644 index 0000000..a734a61 --- /dev/null +++ b/pkg/lang/vm/text/errors.go @@ -0,0 +1,33 @@ +package text + +type ErrInvalidValue struct { + Value string +} + +func (e ErrInvalidValue) Error() string { + return "invalid value: " + e.Value +} + +type ErrDuplicateLabel struct { + Label string +} + +func (e ErrDuplicateLabel) Error() string { + return "duplicate label: " + e.Label +} + +type ErrUnkonwnLabel struct { + Label string +} + +func (e ErrUnkonwnLabel) Error() string { + return "unknown label: " + e.Label +} + +type ErrUnknownOp struct { + Op string +} + +func (e ErrUnknownOp) Error() string { + return "unknown op: " + e.Op +} diff --git a/pkg/lang/vm/text/op.go b/pkg/lang/vm/text/op.go new file mode 100644 index 0000000..a8f3663 --- /dev/null +++ b/pkg/lang/vm/text/op.go @@ -0,0 +1,56 @@ +package text + +import "jinx/pkg/lang/vm/code" + +var ( + opToString = map[code.Op]string{ + code.OpNop: "nop", + code.OpHalt: "halt", + code.OpPushInt: "push_int", + code.OpPushFloat: "push_float", + code.OpPushString: "push_string", + code.OpPushTrue: "push_true", + code.OpPushFalse: "push_false", + code.OpPushNull: "push_null", + code.OpPushArray: "push_array", + code.OpPushFunction: "push_function", + code.OpPushObject: "push_object", + code.OpGetGlobal: "get_global", + code.OpGetLocal: "get_local", + code.OpGetMember: "get_member", + code.OpGetArg: "get_arg", + code.OpGetEnv: "get_env", + code.OpAdd: "add", + code.OpSub: "sub", + code.OpIndex: "index", + code.OpCall: "call", + code.OpJmp: "jmp", + code.OpJez: "jez", + code.OpRet: "ret", + } + stringToOp = reverseMap(opToString) +) + +func OpToString(op code.Op) string { + str, ok := opToString[op] + if !ok { + return "unknown" + } + return str +} + +func StringToOp(str string) (code.Op, error) { + op, ok := stringToOp[str] + if !ok { + return 0, ErrUnknownOp{Op: str} + } + return op, nil +} + +func reverseMap[K comparable, V comparable](m map[K]V) map[V]K { + r := make(map[V]K) + for k, v := range m { + r[v] = k + } + return r +} |
