about summary refs log tree commit diff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/lang/vm/code/code.go10
-rw-r--r--pkg/lang/vm/text/compiler.go202
-rw-r--r--pkg/lang/vm/text/compiler_test.go164
-rw-r--r--pkg/lang/vm/text/errors.go33
-rw-r--r--pkg/lang/vm/text/op.go56
5 files changed, 465 insertions, 0 deletions
diff --git a/pkg/lang/vm/code/code.go b/pkg/lang/vm/code/code.go
index ab06d27..bd25fcd 100644
--- a/pkg/lang/vm/code/code.go
+++ b/pkg/lang/vm/code/code.go
@@ -11,10 +11,20 @@ type Code struct {
 	code []byte
 }
 
+func New(code []byte) Code {
+	return Code{
+		code: code,
+	}
+}
+
 func (c *Code) Len() int {
 	return len(c.code)
 }
 
+func (c *Code) Code() []byte {
+	return c.code
+}
+
 func (c *Code) GetOp(at int) (Op, int) {
 	return Op(c.code[at]), 1
 }
diff --git a/pkg/lang/vm/text/compiler.go b/pkg/lang/vm/text/compiler.go
new file mode 100644
index 0000000..2732aea
--- /dev/null
+++ b/pkg/lang/vm/text/compiler.go
@@ -0,0 +1,202 @@
+package text
+
+import (
+	"encoding/binary"
+	"io"
+	"jinx/pkg/lang/vm/code"
+	"jinx/pkg/libs/source"
+	"math"
+	"strconv"
+	"strings"
+	"unicode"
+)
+
+type Compiler struct {
+	src     source.Walker
+	codePos int
+
+	labelPositions  map[string]int
+	labelReferences []labelReference
+}
+
+func NewCompiler(src io.Reader) *Compiler {
+	return &Compiler{
+		src:     *source.NewWalker(src),
+		codePos: 0,
+
+		labelPositions:  map[string]int{},
+		labelReferences: []labelReference{},
+	}
+}
+
+func (cpl *Compiler) Compile() (code.Code, error) {
+	res := []byte{}
+
+	for {
+		_, eof, err := cpl.src.Peek()
+		if err != nil {
+			return code.Code{}, err
+		}
+
+		if eof {
+			break
+		}
+
+		line, err := cpl.compileLine()
+		if err != nil {
+			return code.Code{}, err
+		}
+
+		cpl.codePos += len(line)
+		res = append(res, line...)
+	}
+
+	if err := cpl.linkLabels(res); err != nil {
+		return code.Code{}, err
+	}
+
+	return code.New(res), nil
+}
+
+func (cpl *Compiler) compileLine() ([]byte, error) {
+	res := []byte{}
+
+	start, value, err := cpl.splitLine()
+	if err != nil || start == "" {
+		return nil, err
+	}
+
+	// Ignore lines starting with a comment.
+	if start[0] == '#' {
+		return nil, nil
+	}
+
+	// Save the position of the label.
+	if start[0] == '@' {
+		label := strings.Trim(start, "@:")
+		if _, ok := cpl.labelPositions[label]; ok {
+			return nil, ErrDuplicateLabel{label}
+		}
+		cpl.labelPositions[label] = cpl.codePos
+		return nil, nil
+	}
+
+	// Find the operator.
+	op, err := cpl.compileOp(start)
+	if err != nil {
+		return nil, err
+	}
+
+	res = append(res, byte(op))
+
+	// Find the value, of which there is at most one.
+	if value != "" {
+		val, err := cpl.compileValue(value)
+		if err != nil {
+			return nil, err
+		}
+
+		res = append(res, val...)
+	}
+
+	return res, nil
+}
+
+func (cpl *Compiler) compileOp(str string) (code.Op, error) {
+	op, err := StringToOp(str)
+	if err != nil {
+		return 0, err
+	}
+
+	return op, nil
+}
+
+func (cpl *Compiler) compileValue(str string) ([]byte, error) {
+	res := make([]byte, 8)
+
+	// Save label reference.
+	if str[0] == '@' {
+		label := strings.Trim(str, "@:")
+		cpl.labelReferences = append(cpl.labelReferences, labelReference{
+			label: label,
+			at:    cpl.codePos + 1, // +1 to skip the opcode.
+		})
+		return res, nil
+	}
+
+	if unicode.IsDigit(rune(str[0])) || str[0] == '-' {
+		if strings.Contains(str, ".") {
+			val, err := strconv.ParseFloat(str, 64)
+			if err != nil {
+				return res, err
+			}
+
+			binary.LittleEndian.PutUint64(res, math.Float64bits(val))
+		} else {
+			val, err := strconv.ParseInt(str, 10, 64)
+			if err != nil {
+				return res, err
+			}
+
+			binary.LittleEndian.PutUint64(res, uint64(val))
+		}
+
+		return res, nil
+	}
+
+	if str[0] == '"' {
+		str = strings.Trim(str, "\"")
+		res = []byte(str)
+		res = append(res, 0)
+
+		return res, nil
+	}
+
+	return res, ErrInvalidValue{str}
+}
+
+func (cpl *Compiler) splitLine() (string, string, error) {
+	line := ""
+
+	for {
+		c, eof, err := cpl.src.Next()
+		if err != nil {
+			return "", "", err
+		}
+
+		if eof || c == '\n' {
+			break
+		}
+
+		line += string(c)
+	}
+
+	fields := strings.Fields(line)
+	if len(fields) == 0 {
+		return "", "", nil
+	}
+
+	start := fields[0]
+	if len(fields) > 1 {
+		return start, strings.Join(fields[1:], " "), nil
+	}
+
+	return start, "", nil
+}
+
+type labelReference struct {
+	label string
+	at    int
+}
+
+func (cpl *Compiler) linkLabels(code []byte) error {
+	for _, ref := range cpl.labelReferences {
+		pos, ok := cpl.labelPositions[ref.label]
+		if !ok {
+			return ErrUnkonwnLabel{ref.label}
+		}
+
+		binary.LittleEndian.PutUint64(code[ref.at:], uint64(pos))
+	}
+	return nil
+}
diff --git a/pkg/lang/vm/text/compiler_test.go b/pkg/lang/vm/text/compiler_test.go
new file mode 100644
index 0000000..cf2f6a9
--- /dev/null
+++ b/pkg/lang/vm/text/compiler_test.go
@@ -0,0 +1,164 @@
+package text_test
+
+import (
+	"encoding/binary"
+	"jinx/pkg/lang/vm/code"
+	"jinx/pkg/lang/vm/text"
+	"math"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestSimple(t *testing.T) {
+	src := `
+	get_arg
+	get_arg
+	sub
+	ret
+	`
+
+	c := text.NewCompiler(strings.NewReader(src))
+	res, err := c.Compile()
+	require.NoError(t, err)
+
+	parts := [][]byte{
+		opBin(code.OpGetArg),
+		opBin(code.OpGetArg),
+		opBin(code.OpSub),
+		opBin(code.OpRet),
+	}
+
+	require.Equal(t, joinSlices(parts), res.Code())
+}
+
+func TestInt(t *testing.T) {
+	src := `
+	push_int 1
+	push_int 2
+	add
+	ret
+	`
+
+	c := text.NewCompiler(strings.NewReader(src))
+	res, err := c.Compile()
+	require.NoError(t, err)
+
+	parts := [][]byte{
+		opBin(code.OpPushInt),
+		uintBin(1),
+		opBin(code.OpPushInt),
+		uintBin(2),
+		opBin(code.OpAdd),
+		opBin(code.OpRet),
+	}
+
+	require.Equal(t, joinSlices(parts), res.Code())
+}
+
+func TestFloat(t *testing.T) {
+	src := `
+	push_float 3.1415
+	push_float -2.71828
+	`
+
+	c := text.NewCompiler(strings.NewReader(src))
+	res, err := c.Compile()
+	require.NoError(t, err)
+
+	parts := [][]byte{
+		opBin(code.OpPushFloat),
+		floatBin(3.1415),
+		opBin(code.OpPushFloat),
+		floatBin(-2.71828),
+	}
+
+	require.Equal(t, joinSlices(parts), res.Code())
+}
+
+func TestString(t *testing.T) {
+	src := `
+	push_string "Hello, "
+	push_string "world!"
+	add
+	`
+
+	c := text.NewCompiler(strings.NewReader(src))
+	res, err := c.Compile()
+	require.NoError(t, err)
+
+	parts := [][]byte{
+		opBin(code.OpPushString),
+		stringBin("Hello, "),
+		opBin(code.OpPushString),
+		stringBin("world!"),
+		opBin(code.OpAdd),
+	}
+
+	require.Equal(t, joinSlices(parts), res.Code())
+}
+
+func TestLabels(t *testing.T) {
+	src := `
+	@1:
+		nop
+	@2:
+		nop
+	@3:
+		nop
+		jmp @1
+		jmp @2
+		jmp @3
+	`
+
+	c := text.NewCompiler(strings.NewReader(src))
+	res, err := c.Compile()
+	require.NoError(t, err)
+
+	parts := [][]byte{
+		opBin(code.OpNop),
+		opBin(code.OpNop),
+		opBin(code.OpNop),
+		opBin(code.OpJmp),
+		uintBin(0),
+		opBin(code.OpJmp),
+		uintBin(1),
+		opBin(code.OpJmp),
+		uintBin(2),
+	}
+
+	require.Equal(t, joinSlices(parts), res.Code())
+}
+
+func opBin(op code.Op) []byte {
+	return []byte{byte(op)}
+}
+
+func uintBin(x uint64) []byte {
+	res := make([]byte, 8)
+	binary.LittleEndian.PutUint64(res, x)
+	return res
+}
+
+func floatBin(x float64) []byte {
+	res := make([]byte, 8)
+	binary.LittleEndian.PutUint64(res, math.Float64bits(x))
+	return res
+}
+
+func stringBin(x string) []byte {
+	res := []byte(x)
+	res = append(res, 0)
+	return res
+}
+
+func joinSlices[T any](slices [][]T) []T {
+	res := []T{}
+
+	for _, slice := range slices {
+		res = append(res, slice...)
+	}
+
+	return res
+}
diff --git a/pkg/lang/vm/text/errors.go b/pkg/lang/vm/text/errors.go
new file mode 100644
index 0000000..a734a61
--- /dev/null
+++ b/pkg/lang/vm/text/errors.go
@@ -0,0 +1,33 @@
+package text
+
+type ErrInvalidValue struct {
+	Value string
+}
+
+func (e ErrInvalidValue) Error() string {
+	return "invalid value: " + e.Value
+}
+
+type ErrDuplicateLabel struct {
+	Label string
+}
+
+func (e ErrDuplicateLabel) Error() string {
+	return "duplicate label: " + e.Label
+}
+
+type ErrUnkonwnLabel struct {
+	Label string
+}
+
+func (e ErrUnkonwnLabel) Error() string {
+	return "unknown label: " + e.Label
+}
+
+type ErrUnknownOp struct {
+	Op string
+}
+
+func (e ErrUnknownOp) Error() string {
+	return "unknown op: " + e.Op
+}
diff --git a/pkg/lang/vm/text/op.go b/pkg/lang/vm/text/op.go
new file mode 100644
index 0000000..a8f3663
--- /dev/null
+++ b/pkg/lang/vm/text/op.go
@@ -0,0 +1,56 @@
+package text
+
+import "jinx/pkg/lang/vm/code"
+
+var (
+	opToString = map[code.Op]string{
+		code.OpNop:          "nop",
+		code.OpHalt:         "halt",
+		code.OpPushInt:      "push_int",
+		code.OpPushFloat:    "push_float",
+		code.OpPushString:   "push_string",
+		code.OpPushTrue:     "push_true",
+		code.OpPushFalse:    "push_false",
+		code.OpPushNull:     "push_null",
+		code.OpPushArray:    "push_array",
+		code.OpPushFunction: "push_function",
+		code.OpPushObject:   "push_object",
+		code.OpGetGlobal:    "get_global",
+		code.OpGetLocal:     "get_local",
+		code.OpGetMember:    "get_member",
+		code.OpGetArg:       "get_arg",
+		code.OpGetEnv:       "get_env",
+		code.OpAdd:          "add",
+		code.OpSub:          "sub",
+		code.OpIndex:        "index",
+		code.OpCall:         "call",
+		code.OpJmp:          "jmp",
+		code.OpJez:          "jez",
+		code.OpRet:          "ret",
+	}
+	stringToOp = reverseMap(opToString)
+)
+
+func OpToString(op code.Op) string {
+	str, ok := opToString[op]
+	if !ok {
+		return "unknown"
+	}
+	return str
+}
+
+func StringToOp(str string) (code.Op, error) {
+	op, ok := stringToOp[str]
+	if !ok {
+		return 0, ErrUnknownOp{Op: str}
+	}
+	return op, nil
+}
+
+func reverseMap[K comparable, V comparable](m map[K]V) map[V]K {
+	r := make(map[V]K)
+	for k, v := range m {
+		r[v] = k
+	}
+	return r
+}