about summary refs log tree commit diff
path: root/pkg
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-07-28 22:11:02 +0000
committerMel <einebeere@gmail.com>2022-07-28 22:11:02 +0000
commit5a6d4664e4417763b4a7d9f215e42102fa1b3fd4 (patch)
tree525f8151bd1bb604ce015425126c5f3dfc84a32c /pkg
parent95c742ef729a657198be43dc2f295f249860332f (diff)
downloadjinx-5a6d4664e4417763b4a7d9f215e42102fa1b3fd4.tar.zst
jinx-5a6d4664e4417763b4a7d9f215e42102fa1b3fd4.zip
Compile type declarations correctly
Diffstat (limited to 'pkg')
-rw-r--r--pkg/lang/compiler/compiler.go249
-rw-r--r--pkg/lang/compiler/compiler_test.go80
-rw-r--r--pkg/lang/compiler/scope/scope_chain.go10
-rw-r--r--pkg/lang/compiler/scope/scopes.go32
-rw-r--r--pkg/lang/vm/exec.go44
-rw-r--r--pkg/lang/vm/value/data.go11
6 files changed, 375 insertions, 51 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index 9c4bda4..8010b16 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -52,7 +52,8 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 		fnDeclStmt := stmt.Value.(ast.StmtFnDecl)
 		err = comp.compileFnDeclStmt(t, fnDeclStmt)
 	case ast.StmtKindTypeDecl:
-		panic("type declaration statements not implemented")
+		typeDeclStmt := stmt.Value.(ast.StmtTypeDecl)
+		err = comp.compileTypeDeclStmt(t, typeDeclStmt)
 	case ast.StmtKindVarDecl:
 		decl := stmt.Value.(ast.StmtVarDecl)
 		err = comp.compileVarDeclStmt(t, decl)
@@ -102,29 +103,13 @@ func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDe
 
 	comp.scopes.EnterFunction(marker)
 
-	for _, arg := range fnDeclStmt.Args {
-		if _, ok := comp.scopes.Declare(arg.Value); !ok {
-			return fmt.Errorf("variable %s already declared", arg.Value)
-		}
-	}
-
-	if err := comp.compileBlockNode(&functionTarget, fnDeclStmt.Body); err != nil {
+	if err := comp.compileFn(&functionTarget, fnDeclStmt.Body, fnDeclStmt.Args, true); err != nil {
 		return err
 	}
 
-	// If the function did not end with a return statement, we need to add an OpRet for safety.
-	lastStmt := fnDeclStmt.Body.Stmts[len(fnDeclStmt.Body.Stmts)-1]
-	// TODO: Get rid of EmptyStmt so we can use the Kind field to determine if the last statement is a return statement.
-	if lastStmt.Kind != ast.StmtKindReturn {
-		functionTarget.AppendOp(code.OpPushNull)
-		functionTarget.AppendOp(code.OpRet)
-	}
-
 	fnScope := comp.scopes.CurrentFunction()
 	_ = comp.scopes.Exit() // Function declaration scopes do not pollute stack
 
-	comp.funcs = append(comp.funcs, &functionTarget)
-
 	// Put the function value on the stack
 
 	t.AppendOp(code.OpPushFunction)
@@ -148,6 +133,106 @@ func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDe
 	return nil
 }
 
+func (comp *Compiler) compileTypeDeclStmt(t *code.Builder, typeDeclStmt ast.StmtTypeDecl) error {
+	typeLocal, ok := comp.scopes.Declare(typeDeclStmt.Name.Value)
+	if !ok {
+		return fmt.Errorf("type %s already declared", typeDeclStmt.Name.Value)
+	}
+
+	t.AppendOp(code.OpPushType)
+	t.AppendString(typeDeclStmt.Name.Value)
+
+	parentTypeMarker := comp.scopes.CreateFunctionSubUnit(typeDeclStmt.Name.Value)
+
+	constructorDeclared := false
+	declaredMethods := make(map[string]struct{})
+
+	// Compile the methods
+	for _, method := range typeDeclStmt.Methods {
+		methodTarget := code.NewBuilder()
+
+		if method.IsConstructor {
+			if constructorDeclared {
+				return fmt.Errorf("constructor for type %s already declared", typeDeclStmt.Name.Value)
+			}
+			constructorDeclared = true
+
+			initMarker := parentTypeMarker.SubMarker("$init")
+			methodTarget.PutMarker(initMarker)
+
+			comp.scopes.EnterConstructor(initMarker, len(method.Args))
+
+			methodTarget.AppendOp(code.OpPushObject)
+			methodTarget.AppendOp(code.OpGetEnv)
+			methodTarget.AppendInt(int64(0))
+			methodTarget.AppendOp(code.OpAnchorType)
+		} else {
+			if _, ok := declaredMethods[method.Name.Value]; ok {
+				return fmt.Errorf("method %s for type %s already declared", method.Name.Value, typeDeclStmt.Name.Value)
+			}
+			declaredMethods[method.Name.Value] = struct{}{}
+
+			methodMarker := parentTypeMarker.SubMarker(method.Name.Value)
+			methodTarget.PutMarker(methodMarker)
+
+			if !method.HasThis {
+				panic("static methods not implemented")
+			}
+
+			comp.scopes.EnterMethod(methodMarker)
+		}
+
+		if err := comp.compileFn(&methodTarget, method.Body, method.Args, !method.IsConstructor); err != nil {
+			return err
+		}
+
+		if method.IsConstructor {
+			// Return the constructed object
+			_, thisLocal := comp.scopes.CurrentFunction().ThisLocal()
+			methodTarget.AppendOp(code.OpGetLocal)
+			methodTarget.AppendInt(int64(thisLocal))
+			methodTarget.AppendOp(code.OpRet)
+		}
+
+		_ = comp.scopes.Exit()
+	}
+
+	if !constructorDeclared {
+		return fmt.Errorf("type %s must have a constructor", typeDeclStmt.Name.Value)
+	}
+
+	// Add methods to the type
+	for _, method := range typeDeclStmt.Methods {
+		t.AppendOp(code.OpGetLocal)
+		t.AppendInt(int64(typeLocal))
+
+		t.AppendOp(code.OpGetMember)
+		t.AppendString("$add_method")
+
+		name := method.Name.Value
+		if method.IsConstructor {
+			name = "$init"
+		}
+
+		t.AppendOp(code.OpPushString)
+		t.AppendString(name)
+
+		marker := parentTypeMarker.SubMarker(name)
+		t.AppendOp(code.OpPushFunction)
+		t.AppendMarkerReference(marker)
+
+		if len(method.Args) != 0 {
+			t.AppendOp(code.OpSetArgCount)
+			t.AppendInt(int64(len(method.Args)))
+		}
+
+		t.AppendOp(code.OpCall)
+		t.AppendInt(int64(2))
+	}
+
+	return nil
+}
+
 func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error {
 	if err := comp.compileExpr(t, decl.Value); err != nil {
 		return err
@@ -471,6 +556,8 @@ func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 		return comp.compileCallExpr(t, expr.Value.(ast.ExprCall))
 	case ast.ExprKindSubscription:
 		return comp.compileSubscriptionExpr(t, expr.Value.(ast.ExprSubscription))
+	case ast.ExprKindMember:
+		return comp.compileMemberExpr(t, expr.Value.(ast.ExprMember))
 
 	case ast.ExprKindGroup:
 		return comp.compileGroupExpr(t, expr.Value.(ast.ExprGroup))
@@ -491,7 +578,7 @@ func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 	case ast.ExprKindNullLit:
 		return comp.compileNullLitExpr(t, expr.Value.(ast.ExprNullLit))
 	case ast.ExprKindThis:
-		panic("not implemented")
+		return comp.compileThisExpr(t, expr.Value.(ast.ExprThis))
 	default:
 		panic("unknown expression kind")
 	}
@@ -543,33 +630,48 @@ func (comp *Compiler) compileBinaryExpr(t *code.Builder, expr ast.ExprBinary) er
 }
 
 func (comp *Compiler) compileAssignExpr(t *code.Builder, expr ast.ExprBinary) error {
-	if expr.Left.Kind != ast.ExprKindIdent {
-		return fmt.Errorf("lvalues other than identifiers not implemented")
-	}
+	switch expr.Left.Kind {
+	case ast.ExprKindIdent:
+		name := expr.Left.Value.(ast.ExprIdent).Value.Value
+		symbolId, ok := comp.scopes.Lookup(name)
+		if !ok {
+			return fmt.Errorf("variable %s not declared", name)
+		}
 
-	name := expr.Left.Value.(ast.ExprIdent).Value.Value
-	symbolId, ok := comp.scopes.Lookup(name)
-	if !ok {
-		return fmt.Errorf("variable %s not declared", name)
-	}
+		if err := comp.compileExpr(t, expr.Right); err != nil {
+			return err
+		}
 
-	if err := comp.compileExpr(t, expr.Right); err != nil {
-		return err
-	}
+		switch symbolId.SymbolKind() {
+		case scope.SymbolKindVariable:
+			symbol := comp.scopes.GetVariable(symbolId)
 
-	switch symbolId.SymbolKind() {
-	case scope.SymbolKindVariable:
-		symbol := comp.scopes.GetVariable(symbolId)
+			t.AppendOp(code.OpSetLocal)
+			t.AppendInt(int64(symbol.Data().LocalIndex()))
+		case scope.SymbolKindEnv:
+			symbol := comp.scopes.GetEnv(symbolId)
 
-		t.AppendOp(code.OpSetLocal)
-		t.AppendInt(int64(symbol.Data().LocalIndex()))
-	case scope.SymbolKindEnv:
-		symbol := comp.scopes.GetEnv(symbolId)
+			t.AppendOp(code.OpSetEnv)
+			t.AppendInt(int64(symbol.Data().IndexInEnv()))
+		default:
+			panic("unknown symbol kind")
+		}
+	case ast.ExprKindMember:
+		memberExpr := expr.Left.Value.(ast.ExprMember)
+		if err := comp.compileExpr(t, memberExpr.Obj); err != nil {
+			return err
+		}
 
-		t.AppendOp(code.OpSetEnv)
-		t.AppendInt(int64(symbol.Data().IndexInEnv()))
+		if err := comp.compileExpr(t, expr.Right); err != nil {
+			return err
+		}
+
+		name := memberExpr.Key.Value
+
+		t.AppendOp(code.OpSetMember)
+		t.AppendString(name)
 	default:
-		panic("unknown symbol kind")
+		return fmt.Errorf("invalid left-hand side of assignment to %v", expr.Left.Kind)
 	}
 
 	return nil
@@ -622,6 +724,19 @@ func (comp *Compiler) compileSubscriptionExpr(t *code.Builder, expr ast.ExprSubs
 	return nil
 }
 
+func (comp *Compiler) compileMemberExpr(t *code.Builder, memberExpr ast.ExprMember) error {
+	if err := comp.compileExpr(t, memberExpr.Obj); err != nil {
+		return err
+	}
+
+	name := memberExpr.Key.Value
+
+	t.AppendOp(code.OpGetMember)
+	t.AppendString(name)
+
+	return nil
+}
+
 func (comp *Compiler) compileGroupExpr(t *code.Builder, expr ast.ExprGroup) error {
 	return comp.compileExpr(t, expr.Value)
 }
@@ -705,6 +820,23 @@ func (comp *Compiler) compileNullLitExpr(t *code.Builder, expr ast.ExprNullLit)
 	return nil
 }
 
+func (comp *Compiler) compileThisExpr(t *code.Builder, expr ast.ExprThis) error {
+	currentFn := comp.scopes.CurrentFunction()
+	if !currentFn.IsMethod() {
+		return fmt.Errorf("this can only be used in methods")
+	}
+
+	if isLocal, localIndex := currentFn.ThisLocal(); isLocal {
+		t.AppendOp(code.OpGetLocal)
+		t.AppendInt(int64(localIndex))
+	} else {
+		t.AppendOp(code.OpGetEnv)
+		t.AppendInt(int64((0)))
+	}
+
+	return nil
+}
+
 func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) error {
 	for _, stmt := range block.Stmts {
 		if err := comp.compileStmt(t, stmt); err != nil {
@@ -715,6 +847,43 @@ func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) err
 	return nil
 }
 
+func (comp *Compiler) compileFn(t *code.Builder, block ast.BlockNode, args []ast.IdentNode, addMissingReturn bool) error {
+	// Arguments are declared in reverse
+	for i := len(args) - 1; i >= 0; i-- {
+		arg := args[i]
+		if _, ok := comp.scopes.Declare(arg.Value); !ok {
+			return fmt.Errorf("variable %s already declared", arg.Value)
+		}
+	}
+
+	// If we're in a constructor, we need to declare the this argument,
+	// which comes after the normal arguments.
+	currentFn := comp.scopes.CurrentFunction()
+	if isLocal, localIndex := currentFn.ThisLocal(); isLocal {
+		if localIndex != comp.scopes.DeclareAnonymous() {
+			panic("this local did not match expected position")
+		}
+	}
+
+	if err := comp.compileBlockNode(t, block); err != nil {
+		return err
+	}
+
+	if addMissingReturn {
+		// If the function did not end with a return statement, we need to add an OpRet for safety.
+		lastStmt := block.Stmts[len(block.Stmts)-1]
+		// TODO: Get rid of EmptyStmt so we can use the Kind field to determine if the last statement is a return statement.
+		if lastStmt.Kind != ast.StmtKindReturn {
+			t.AppendOp(code.OpPushNull)
+			t.AppendOp(code.OpRet)
+		}
+	}
+
+	comp.funcs = append(comp.funcs, t)
+
+	return nil
+}
+
 func (comp *Compiler) exitScopeAndCleanStack(t *code.Builder) {
 	if stackSpace := comp.scopes.Exit(); stackSpace != 0 {
 		t.AppendOp(code.OpDrop)
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index 9c21a3b..04bf425 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -457,6 +457,7 @@ func TestSimpleFunction(t *testing.T) {
 }
 
 func TestFunctionArgs(t *testing.T) {
+	// TODO: Are arguments in the correct order?
 	src := `
 	fn add(a, b) {
 		return a + b
@@ -477,8 +478,8 @@ func TestFunctionArgs(t *testing.T) {
 	halt
 
 	@add:
-	get_local 0
 	get_local 1
+	get_local 0
 	add
 	ret
 	`
@@ -523,6 +524,83 @@ func TestClosureEnv(t *testing.T) {
 	mustCompileTo(t, src, expected)
 }
 
+func TestType(t *testing.T) {
+	src := `
+	type Cat {
+		(name, age) {
+			this.name = name
+			this.age = age
+		}
+
+		fn meow(this) {
+			return this.name + " says Meow!"
+		}
+	}
+
+	var kitty = Cat("Kitty", 3)
+	kitty.meow()
+	`
+
+	expected := `
+	push_type "Cat"
+
+	get_local 0
+	get_member "$add_method"
+
+	push_string "$init"
+	push_function @Cat:$init
+	set_arg_count 2
+
+	call 2
+
+	get_local 0
+	get_member "$add_method"
+
+	push_string "meow"
+	push_function @Cat:meow
+
+	call 2
+
+	get_local 0
+	push_string "Kitty"
+	push_int 3
+	call 2
+
+	get_local 1
+	get_member "meow"
+	call 0
+	drop 1
+
+	halt
+
+	@Cat:$init:
+		push_object
+		get_env 0
+		anchor_type
+
+		get_local 2
+		get_local 1
+		set_member "name"
+
+		get_local 2
+		get_local 0
+		set_member "age"
+
+		get_local 2
+		ret
+	@Cat:meow:
+		get_env 0
+		get_member "name"
+
+		push_string " says Meow!"
+
+		add
+		ret
+	`
+
+	mustCompileTo(t, src, expected)
+}
+
 func mustCompileTo(t *testing.T, src, expected string) {
 	scanner := scanner.New(strings.NewReader(src))
 	tokens, err := scanner.Scan()
diff --git a/pkg/lang/compiler/scope/scope_chain.go b/pkg/lang/compiler/scope/scope_chain.go
index 1b83c75..f386017 100644
--- a/pkg/lang/compiler/scope/scope_chain.go
+++ b/pkg/lang/compiler/scope/scope_chain.go
@@ -72,6 +72,16 @@ func (sc *ScopeChain) EnterFunction(unit code.Marker) {
 	sc.functionScopes = append(sc.functionScopes, NewFunctionScope(sc.CurrentScopeID(), unit))
 }
 
+func (sc *ScopeChain) EnterConstructor(unit code.Marker, thisLocal int) {
+	sc.Enter()
+	sc.functionScopes = append(sc.functionScopes, NewMethodFunctionScope(sc.CurrentScopeID(), unit, thisLocal))
+}
+
+func (sc *ScopeChain) EnterMethod(unit code.Marker) {
+	sc.Enter()
+	sc.functionScopes = append(sc.functionScopes, NewMethodFunctionScope(sc.CurrentScopeID(), unit, -1))
+}
+
 func (sc *ScopeChain) EnterLoop() (code.Marker, code.Marker) {
 	parentMarker := sc.CreateAnonymousFunctionSubUnit()
 
diff --git a/pkg/lang/compiler/scope/scopes.go b/pkg/lang/compiler/scope/scopes.go
index 7a1b20c..2a9453a 100644
--- a/pkg/lang/compiler/scope/scopes.go
+++ b/pkg/lang/compiler/scope/scopes.go
@@ -28,6 +28,9 @@ type FunctionScope struct {
 	subUnitCount int
 
 	outsideSymbolsInEnv []SymbolID
+
+	isMethod    bool
+	methodLocal int // -1 if in env 0
 }
 
 func NewFunctionScope(id ScopeID, unit code.Marker) FunctionScope {
@@ -40,6 +43,19 @@ func NewFunctionScope(id ScopeID, unit code.Marker) FunctionScope {
 	}
 }
 
+func NewMethodFunctionScope(id ScopeID, unit code.Marker, methodLocal int) FunctionScope {
+	return FunctionScope{
+		id:           id,
+		unit:         unit,
+		subUnitCount: 0,
+
+		outsideSymbolsInEnv: make([]SymbolID, 0),
+
+		isMethod:    true,
+		methodLocal: methodLocal,
+	}
+}
+
 func (sf FunctionScope) ID() ScopeID {
 	return sf.id
 }
@@ -56,6 +72,22 @@ func (sf FunctionScope) IsRootScope() bool {
 	return sf.ID() == ScopeID(0)
 }
 
+func (sf FunctionScope) IsMethod() bool {
+	return sf.isMethod
+}
+
+func (sf FunctionScope) ThisLocal() (bool, int) {
+	if !sf.isMethod {
+		return false, 0
+	}
+
+	if sf.methodLocal == -1 {
+		return false, 0
+	}
+
+	return true, sf.methodLocal
+}
+
 type LoopScope struct {
 	id             ScopeID
 	breakMarker    code.Marker
diff --git a/pkg/lang/vm/exec.go b/pkg/lang/vm/exec.go
index 5a8eb05..353d8ac 100644
--- a/pkg/lang/vm/exec.go
+++ b/pkg/lang/vm/exec.go
@@ -208,17 +208,11 @@ func (vm *VM) execGetMember(name string) error {
 
 	if parent.Type() == value.TypeRefType {
 		ref := parent.Data().(value.TypeRefData)
-		ptr := ref.TypeRef()
-
-		cell, err := vm.getMemCell(ptr, mem.CellKindType, false)
+		ok, methodData, err := ref.GetMethod(vm.memory, name)
 		if err != nil {
 			return err
 		}
 
-		typ := cell.(value.TypeCell).Get()
-
-		methodData, ok := typ.GetMethod(name)
-
 		if !methodData.Env().IsNull() {
 			panic("methods with environments not implemented yet")
 		}
@@ -831,15 +825,45 @@ func (vm *VM) execCall(argCount uint) error {
 		return err
 	}
 
-	if f.Type() != value.FunctionType {
+	var fn value.FunctionData
+	// Constructor call
+	if f.Type() == value.TypeRefType {
+		t := f.Data().(value.TypeRefData)
+		ok, initMethod, err := t.GetMethod(vm.memory, "$init")
+		if err != nil {
+			return err
+		}
+
+		if !ok {
+			panic("constructor not found on type")
+		}
+
+		// TODO: Very unsure about this, it's duplicated from execGetMember, probably need to refactor.
+		if f.Outlet().IsNull() {
+			outletPtr, err := vm.memory.Allocate(mem.CellKindOutlet)
+			if err != nil {
+				return err
+			}
+			f = f.WithOutlet(outletPtr)
+		}
+
+		newEnv := value.NewEnv()
+		newEnv.Add(0, f.Outlet())
+		envPtr, err := newEnv.Allocate(vm.memory)
+		if err != nil {
+			return err
+		}
+
+		fn = initMethod.WithEnv(envPtr)
+	} else if f.Type() == value.FunctionType {
+		fn = f.Data().(value.FunctionData)
+	} else {
 		return ErrInvalidOperandType{
 			Op: code.OpCall,
 			X:  f.Type(),
 		}
 	}
 
-	fn := f.Data().(value.FunctionData)
-
 	if argCount != fn.Args() {
 		return ErrWrongNumberOfArguments{
 			Got:    argCount,
diff --git a/pkg/lang/vm/value/data.go b/pkg/lang/vm/value/data.go
index 6c3d762..6ec7be9 100644
--- a/pkg/lang/vm/value/data.go
+++ b/pkg/lang/vm/value/data.go
@@ -180,6 +180,17 @@ func (t TypeRefData) TypeRef() mem.Ptr {
 	return t.typeRef
 }
 
+func (t TypeRefData) GetMethod(m mem.Mem, name string) (bool, FunctionData, error) {
+	cell, err := m.Get(t.typeRef)
+	if err != nil {
+		return false, FunctionData{}, err
+	}
+
+	typ := cell.(TypeCell).Get()
+	method, ok := typ.Methods[name]
+	return ok, method, nil
+}
+
 func (t TypeRefData) AddMethod(m mem.Mem, name string, method FunctionData) error {
 	cell, err := m.Get(t.typeRef)
 	if err != nil {