about summary refs log tree commit diff
path: root/pkg
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-06-20 00:37:01 +0200
committerMel <einebeere@gmail.com>2022-06-20 00:37:01 +0200
commit621f624f50a7bef16eeed02113b470e79e790cd9 (patch)
tree058767579a4c542b173ee001cc7e89a853d79e25 /pkg
parent9a847f9ec4a0030bf2194005bc9a79cd609cd48a (diff)
downloadjinx-621f624f50a7bef16eeed02113b470e79e790cd9.tar.zst
jinx-621f624f50a7bef16eeed02113b470e79e790cd9.zip
Compile rudimetary variables
Diffstat (limited to 'pkg')
-rw-r--r--pkg/lang/compiler/compiler.go85
-rw-r--r--pkg/lang/compiler/compiler_test.go27
-rw-r--r--pkg/lang/compiler/scope_chain.go76
-rw-r--r--pkg/lang/compiler/symbol.go13
4 files changed, 189 insertions, 12 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index e299fdc..d2b332d 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -1,6 +1,7 @@
 package compiler
 
 import (
+	"fmt"
 	"jinx/pkg/lang/ast"
 	"jinx/pkg/lang/vm/code"
 )
@@ -8,33 +9,57 @@ import (
 type Compiler struct {
 	ast     ast.Program
 	builder code.Builder
+
+	scopes ScopeChain
 }
 
 func New(ast ast.Program) *Compiler {
 	return &Compiler{
 		ast:     ast,
 		builder: code.NewBuilder(),
+
+		scopes: NewScopeChain(),
 	}
 }
 
 func (comp *Compiler) Compile() (code.Code, error) {
 	for _, stmt := range comp.ast.Stmts {
-		if stmt.Kind == ast.StmtKindEmpty {
-			continue
+		if err := comp.compileStmt(stmt); err != nil {
+			return code.Code{}, err
 		}
+	}
 
-		if stmt.Kind != ast.StmtKindExpr {
-			panic("statements other than expressions and empties not implemented")
-		}
+	return comp.builder.Build(), nil
+}
 
+func (comp *Compiler) compileStmt(stmt ast.Stmt) error {
+	var err error
+	switch stmt.Kind {
+	case ast.StmtKindEmpty:
+		// Do nothing.
+	case ast.StmtKindVarDecl:
+		decl := stmt.Value.(ast.StmtVarDecl)
+		err = comp.compileVarDeclStmt(decl)
+	case ast.StmtKindExpr:
 		expr := stmt.Value.(ast.StmtExpr).Value
+		err = comp.compileExpr(expr)
+	default:
+		panic("statements other than expressions, variable declarations, var and empties not implemented")
+	}
 
-		if err := comp.compileExpr(expr); err != nil {
-			return code.Code{}, err
-		}
+	return err
+}
+
+func (comp *Compiler) compileVarDeclStmt(decl ast.StmtVarDecl) error {
+	if !comp.scopes.Declare(decl.Name.Value) {
+		return fmt.Errorf("variable %s already declared", decl.Name.Value)
 	}
 
-	return comp.builder.Build(), nil
+	if err := comp.compileExpr(decl.Value); err != nil {
+		return err
+	}
+
+	return nil
 }
 
 func (comp *Compiler) compileExpr(expr ast.Expr) error {
@@ -55,7 +80,7 @@ func (comp *Compiler) compileExpr(expr ast.Expr) error {
 	case ast.ExprKindArrayLit:
 		panic("not implemented")
 	case ast.ExprKindIdent:
-		panic("not implemented")
+		return comp.compileIdentExpr(expr.Value.(ast.ExprIdent))
 	case ast.ExprKindIntLit:
 		return comp.compileIntLitExpr(expr.Value.(ast.ExprIntLit))
 	case ast.ExprKindFloatLit:
@@ -74,6 +99,10 @@ func (comp *Compiler) compileExpr(expr ast.Expr) error {
 }
 
 func (comp *Compiler) compileBinaryExpr(expr ast.ExprBinary) error {
+	if expr.Op == ast.BinOpAssign {
+		return comp.compileAssignExpr(expr)
+	}
+
 	if err := comp.compileExpr(expr.Right); err != nil {
 		return err
 	}
@@ -97,8 +126,6 @@ func (comp *Compiler) compileBinaryExpr(expr ast.ExprBinary) error {
 		// comp.builder.AppendOp(code.OpMod)
 		panic("not implemented")
 
-	case ast.BinOpAssign:
-		panic("not implemented")
 	case ast.BinOpEq:
 		// comp.builder.AppendOp(code.OpEq)
 		panic("not implemented")
@@ -123,6 +150,27 @@ func (comp *Compiler) compileBinaryExpr(expr ast.ExprBinary) error {
 	return nil
 }
 
+func (comp *Compiler) compileAssignExpr(expr ast.ExprBinary) error {
+	if expr.Left.Kind != ast.ExprKindIdent {
+		return fmt.Errorf("lvalues other than identifiers not implemented")
+	}
+
+	name := expr.Left.Value.(ast.ExprIdent).Value.Value
+	symbol, ok := comp.scopes.Lookup(name)
+	if !ok {
+		return fmt.Errorf("variable %s not declared", name)
+	}
+
+	if err := comp.compileExpr(expr.Right); err != nil {
+		return err
+	}
+
+	comp.builder.AppendOp(code.OpSetLocal)
+	comp.builder.AppendInt(int64(symbol.localIndex))
+
+	return nil
+}
+
 func (comp *Compiler) compileUnaryExpr(expr ast.ExprUnary) error {
 	if err := comp.compileExpr(expr.Value); err != nil {
 		return err
@@ -174,6 +222,19 @@ func (comp *Compiler) compileGroupExpr(expr ast.ExprGroup) error {
 	return comp.compileExpr(expr.Value)
 }
 
+func (comp *Compiler) compileIdentExpr(expr ast.ExprIdent) error {
+	symbol, ok := comp.scopes.Lookup(expr.Value.Value)
+	if !ok {
+		return fmt.Errorf("undefined symbol %s", expr.Value.Value)
+	}
+
+	// TODO: Add boundries to check how the symbol should be fetched. (local, env, global, etc.)
+	comp.builder.AppendOp(code.OpGetLocal)
+	comp.builder.AppendInt(int64(symbol.localIndex))
+
+	return nil
+}
+
 func (comp *Compiler) compileIntLitExpr(expr ast.ExprIntLit) error {
 	comp.builder.AppendOp(code.OpPushInt)
 	comp.builder.AppendInt(int64(expr.Value))
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index 5f0cd2c..5347830 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -81,6 +81,33 @@ func TestNestedExpr(t *testing.T) {
 	mustCompileTo(t, src, expected)
 }
 
+func TestVars(t *testing.T) {
+	src := `
+	var x = 10
+	var y = 25
+	x = x + 7
+
+	var res = x + y
+	`
+
+	expected := `
+	push_int 10
+
+	push_int 25
+
+	push_int 7
+	get_local 0
+	add
+	set_local 0
+
+	get_local 1
+	get_local 0
+	add
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
 func mustCompileTo(t *testing.T, src, expected string) {
 	scanner := scanner.New(strings.NewReader(src))
 	tokens, err := scanner.Scan()
diff --git a/pkg/lang/compiler/scope_chain.go b/pkg/lang/compiler/scope_chain.go
new file mode 100644
index 0000000..8d942ea
--- /dev/null
+++ b/pkg/lang/compiler/scope_chain.go
@@ -0,0 +1,76 @@
+package compiler
+
+type ScopeID int
+
+type ScopeChain struct {
+	scopes []Scope
+}
+
+func NewScopeChain() ScopeChain {
+	scopes := make([]Scope, 1)
+	scopes[0] = Scope{
+		kind:    ScopeKindGlobal,
+		symbols: make(map[string]Symbol),
+	}
+
+	return ScopeChain{
+		scopes: scopes,
+	}
+}
+
+func (sc *ScopeChain) Current() *Scope {
+	return &sc.scopes[len(sc.scopes)-1]
+}
+
+func (sc *ScopeChain) Enter(kind ScopeKind) {
+	sc.scopes = append(sc.scopes, Scope{
+		kind:    kind,
+		symbols: make(map[string]Symbol),
+	})
+}
+
+func (sc *ScopeChain) Exit() {
+	sc.scopes[len(sc.scopes)-1] = Scope{}
+	sc.scopes = sc.scopes[:len(sc.scopes)-1]
+}
+
+func (sc *ScopeChain) Declare(name string) bool {
+	// Check whether the symbol is already declared in any of the scopes.
+	for _, scope := range sc.scopes {
+		if _, ok := scope.symbols[name]; ok {
+			return false
+		}
+	}
+
+	// Declare the symbol in the current scope.
+	sc.Current().symbols[name] = Symbol{
+		kind:       SymbolKindVariable,
+		name:       name,
+		localIndex: len(sc.Current().symbols),
+	}
+
+	return true
+}
+
+func (sc *ScopeChain) Lookup(name string) (Symbol, bool) {
+	for i := len(sc.scopes) - 1; i >= 0; i-- {
+		if symbol, ok := sc.scopes[i].symbols[name]; ok {
+			return symbol, true
+		}
+	}
+
+	return Symbol{}, false
+}
+
+type ScopeKind int
+
+const (
+	ScopeKindGlobal ScopeKind = iota
+	ScopeKindFunction
+	ScopeKindBlock
+)
+
+type Scope struct {
+	kind    ScopeKind
+	symbols map[string]Symbol
+}
diff --git a/pkg/lang/compiler/symbol.go b/pkg/lang/compiler/symbol.go
new file mode 100644
index 0000000..03838da
--- /dev/null
+++ b/pkg/lang/compiler/symbol.go
@@ -0,0 +1,13 @@
+package compiler
+
+type SymbolKind int
+
+const (
+	SymbolKindVariable SymbolKind = iota
+)
+
+type Symbol struct {
+	kind       SymbolKind
+	name       string
+	localIndex int
+}