about summary refs log tree commit diff
path: root/pkg/lang/compiler/compiler.go
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-07-28 22:11:02 +0000
committerMel <einebeere@gmail.com>2022-07-28 22:11:02 +0000
commit5a6d4664e4417763b4a7d9f215e42102fa1b3fd4 (patch)
tree525f8151bd1bb604ce015425126c5f3dfc84a32c /pkg/lang/compiler/compiler.go
parent95c742ef729a657198be43dc2f295f249860332f (diff)
downloadjinx-5a6d4664e4417763b4a7d9f215e42102fa1b3fd4.tar.zst
jinx-5a6d4664e4417763b4a7d9f215e42102fa1b3fd4.zip
Compile type declarations correctly
Diffstat (limited to 'pkg/lang/compiler/compiler.go')
-rw-r--r--pkg/lang/compiler/compiler.go249
1 files changed, 209 insertions, 40 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index 9c4bda4..8010b16 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -52,7 +52,8 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 		fnDeclStmt := stmt.Value.(ast.StmtFnDecl)
 		err = comp.compileFnDeclStmt(t, fnDeclStmt)
 	case ast.StmtKindTypeDecl:
-		panic("type declaration statements not implemented")
+		typeDeclStmt := stmt.Value.(ast.StmtTypeDecl)
+		err = comp.compileTypeDeclStmt(t, typeDeclStmt)
 	case ast.StmtKindVarDecl:
 		decl := stmt.Value.(ast.StmtVarDecl)
 		err = comp.compileVarDeclStmt(t, decl)
@@ -102,29 +103,13 @@ func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDe
 
 	comp.scopes.EnterFunction(marker)
 
-	for _, arg := range fnDeclStmt.Args {
-		if _, ok := comp.scopes.Declare(arg.Value); !ok {
-			return fmt.Errorf("variable %s already declared", arg.Value)
-		}
-	}
-
-	if err := comp.compileBlockNode(&functionTarget, fnDeclStmt.Body); err != nil {
+	if err := comp.compileFn(&functionTarget, fnDeclStmt.Body, fnDeclStmt.Args, true); err != nil {
 		return err
 	}
 
-	// If the function did not end with a return statement, we need to add an OpRet for safety.
-	lastStmt := fnDeclStmt.Body.Stmts[len(fnDeclStmt.Body.Stmts)-1]
-	// TODO: Get rid of EmptyStmt so we can use the Kind field to determine if the last statement is a return statement.
-	if lastStmt.Kind != ast.StmtKindReturn {
-		functionTarget.AppendOp(code.OpPushNull)
-		functionTarget.AppendOp(code.OpRet)
-	}
-
 	fnScope := comp.scopes.CurrentFunction()
 	_ = comp.scopes.Exit() // Function declaration scopes do not pollute stack
 
-	comp.funcs = append(comp.funcs, &functionTarget)
-
 	// Put the function value on the stack
 
 	t.AppendOp(code.OpPushFunction)
@@ -148,6 +133,106 @@ func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDe
 	return nil
 }
 
+func (comp *Compiler) compileTypeDeclStmt(t *code.Builder, typeDeclStmt ast.StmtTypeDecl) error {
+	typeLocal, ok := comp.scopes.Declare(typeDeclStmt.Name.Value)
+	if !ok {
+		return fmt.Errorf("type %s already declared", typeDeclStmt.Name.Value)
+	}
+
+	t.AppendOp(code.OpPushType)
+	t.AppendString(typeDeclStmt.Name.Value)
+
+	parentTypeMarker := comp.scopes.CreateFunctionSubUnit(typeDeclStmt.Name.Value)
+
+	constructorDeclared := false
+	declaredMethods := make(map[string]struct{})
+
+	// Compile the methods
+	for _, method := range typeDeclStmt.Methods {
+		methodTarget := code.NewBuilder()
+
+		if method.IsConstructor {
+			if constructorDeclared {
+				return fmt.Errorf("constructor for type %s already declared", typeDeclStmt.Name.Value)
+			}
+			constructorDeclared = true
+
+			initMarker := parentTypeMarker.SubMarker("$init")
+			methodTarget.PutMarker(initMarker)
+
+			comp.scopes.EnterConstructor(initMarker, len(method.Args))
+
+			methodTarget.AppendOp(code.OpPushObject)
+			methodTarget.AppendOp(code.OpGetEnv)
+			methodTarget.AppendInt(int64(0))
+			methodTarget.AppendOp(code.OpAnchorType)
+		} else {
+			if _, ok := declaredMethods[method.Name.Value]; ok {
+				return fmt.Errorf("method %s for type %s already declared", method.Name.Value, typeDeclStmt.Name.Value)
+			}
+			declaredMethods[method.Name.Value] = struct{}{}
+
+			methodMarker := parentTypeMarker.SubMarker(method.Name.Value)
+			methodTarget.PutMarker(methodMarker)
+
+			if !method.HasThis {
+				panic("static methods not implemented")
+			}
+
+			comp.scopes.EnterMethod(methodMarker)
+		}
+
+		if err := comp.compileFn(&methodTarget, method.Body, method.Args, !method.IsConstructor); err != nil {
+			return err
+		}
+
+		if method.IsConstructor {
+			// Return the constructed object
+			_, thisLocal := comp.scopes.CurrentFunction().ThisLocal()
+			methodTarget.AppendOp(code.OpGetLocal)
+			methodTarget.AppendInt(int64(thisLocal))
+			methodTarget.AppendOp(code.OpRet)
+		}
+
+		_ = comp.scopes.Exit()
+	}
+
+	if !constructorDeclared {
+		return fmt.Errorf("type %s must have a constructor", typeDeclStmt.Name.Value)
+	}
+
+	// Add methods to the type
+	for _, method := range typeDeclStmt.Methods {
+		t.AppendOp(code.OpGetLocal)
+		t.AppendInt(int64(typeLocal))
+
+		t.AppendOp(code.OpGetMember)
+		t.AppendString("$add_method")
+
+		name := method.Name.Value
+		if method.IsConstructor {
+			name = "$init"
+		}
+
+		t.AppendOp(code.OpPushString)
+		t.AppendString(name)
+
+		marker := parentTypeMarker.SubMarker(name)
+		t.AppendOp(code.OpPushFunction)
+		t.AppendMarkerReference(marker)
+
+		if len(method.Args) != 0 {
+			t.AppendOp(code.OpSetArgCount)
+			t.AppendInt(int64(len(method.Args)))
+		}
+
+		t.AppendOp(code.OpCall)
+		t.AppendInt(int64(2))
+	}
+
+	return nil
+}
+
 func (comp *Compiler) compileVarDeclStmt(t *code.Builder, decl ast.StmtVarDecl) error {
 	if err := comp.compileExpr(t, decl.Value); err != nil {
 		return err
@@ -471,6 +556,8 @@ func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 		return comp.compileCallExpr(t, expr.Value.(ast.ExprCall))
 	case ast.ExprKindSubscription:
 		return comp.compileSubscriptionExpr(t, expr.Value.(ast.ExprSubscription))
+	case ast.ExprKindMember:
+		return comp.compileMemberExpr(t, expr.Value.(ast.ExprMember))
 
 	case ast.ExprKindGroup:
 		return comp.compileGroupExpr(t, expr.Value.(ast.ExprGroup))
@@ -491,7 +578,7 @@ func (comp *Compiler) compileExpr(t *code.Builder, expr ast.Expr) error {
 	case ast.ExprKindNullLit:
 		return comp.compileNullLitExpr(t, expr.Value.(ast.ExprNullLit))
 	case ast.ExprKindThis:
-		panic("not implemented")
+		return comp.compileThisExpr(t, expr.Value.(ast.ExprThis))
 	default:
 		panic("unknown expression kind")
 	}
@@ -543,33 +630,48 @@ func (comp *Compiler) compileBinaryExpr(t *code.Builder, expr ast.ExprBinary) er
 }
 
 func (comp *Compiler) compileAssignExpr(t *code.Builder, expr ast.ExprBinary) error {
-	if expr.Left.Kind != ast.ExprKindIdent {
-		return fmt.Errorf("lvalues other than identifiers not implemented")
-	}
+	switch expr.Left.Kind {
+	case ast.ExprKindIdent:
+		name := expr.Left.Value.(ast.ExprIdent).Value.Value
+		symbolId, ok := comp.scopes.Lookup(name)
+		if !ok {
+			return fmt.Errorf("variable %s not declared", name)
+		}
 
-	name := expr.Left.Value.(ast.ExprIdent).Value.Value
-	symbolId, ok := comp.scopes.Lookup(name)
-	if !ok {
-		return fmt.Errorf("variable %s not declared", name)
-	}
+		if err := comp.compileExpr(t, expr.Right); err != nil {
+			return err
+		}
 
-	if err := comp.compileExpr(t, expr.Right); err != nil {
-		return err
-	}
+		switch symbolId.SymbolKind() {
+		case scope.SymbolKindVariable:
+			symbol := comp.scopes.GetVariable(symbolId)
 
-	switch symbolId.SymbolKind() {
-	case scope.SymbolKindVariable:
-		symbol := comp.scopes.GetVariable(symbolId)
+			t.AppendOp(code.OpSetLocal)
+			t.AppendInt(int64(symbol.Data().LocalIndex()))
+		case scope.SymbolKindEnv:
+			symbol := comp.scopes.GetEnv(symbolId)
 
-		t.AppendOp(code.OpSetLocal)
-		t.AppendInt(int64(symbol.Data().LocalIndex()))
-	case scope.SymbolKindEnv:
-		symbol := comp.scopes.GetEnv(symbolId)
+			t.AppendOp(code.OpSetEnv)
+			t.AppendInt(int64(symbol.Data().IndexInEnv()))
+		default:
+			panic("unknown symbol kind")
+		}
+	case ast.ExprKindMember:
+		memberExpr := expr.Left.Value.(ast.ExprMember)
+		if err := comp.compileExpr(t, memberExpr.Obj); err != nil {
+			return err
+		}
 
-		t.AppendOp(code.OpSetEnv)
-		t.AppendInt(int64(symbol.Data().IndexInEnv()))
+		if err := comp.compileExpr(t, expr.Right); err != nil {
+			return err
+		}
+
+		name := memberExpr.Key.Value
+
+		t.AppendOp(code.OpSetMember)
+		t.AppendString(name)
 	default:
-		panic("unknown symbol kind")
+		return fmt.Errorf("invalid left-hand side of assignment to %v", expr.Left.Kind)
 	}
 
 	return nil
@@ -622,6 +724,19 @@ func (comp *Compiler) compileSubscriptionExpr(t *code.Builder, expr ast.ExprSubs
 	return nil
 }
 
+func (comp *Compiler) compileMemberExpr(t *code.Builder, memberExpr ast.ExprMember) error {
+	if err := comp.compileExpr(t, memberExpr.Obj); err != nil {
+		return err
+	}
+
+	name := memberExpr.Key.Value
+
+	t.AppendOp(code.OpGetMember)
+	t.AppendString(name)
+
+	return nil
+}
+
 func (comp *Compiler) compileGroupExpr(t *code.Builder, expr ast.ExprGroup) error {
 	return comp.compileExpr(t, expr.Value)
 }
@@ -705,6 +820,23 @@ func (comp *Compiler) compileNullLitExpr(t *code.Builder, expr ast.ExprNullLit)
 	return nil
 }
 
+func (comp *Compiler) compileThisExpr(t *code.Builder, expr ast.ExprThis) error {
+	currentFn := comp.scopes.CurrentFunction()
+	if !currentFn.IsMethod() {
+		return fmt.Errorf("this can only be used in methods")
+	}
+
+	if isLocal, localIndex := currentFn.ThisLocal(); isLocal {
+		t.AppendOp(code.OpGetLocal)
+		t.AppendInt(int64(localIndex))
+	} else {
+		t.AppendOp(code.OpGetEnv)
+		t.AppendInt(int64((0)))
+	}
+
+	return nil
+}
+
 func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) error {
 	for _, stmt := range block.Stmts {
 		if err := comp.compileStmt(t, stmt); err != nil {
@@ -715,6 +847,43 @@ func (comp *Compiler) compileBlockNode(t *code.Builder, block ast.BlockNode) err
 	return nil
 }
 
+func (comp *Compiler) compileFn(t *code.Builder, block ast.BlockNode, args []ast.IdentNode, addMissingReturn bool) error {
+	// Arguments are declared in reverse
+	for i := len(args) - 1; i >= 0; i-- {
+		arg := args[i]
+		if _, ok := comp.scopes.Declare(arg.Value); !ok {
+			return fmt.Errorf("variable %s already declared", arg.Value)
+		}
+	}
+
+	// If we're in a constructor, we need to declare the this argument,
+	// which comes after the normal arguments.
+	currentFn := comp.scopes.CurrentFunction()
+	if isLocal, localIndex := currentFn.ThisLocal(); isLocal {
+		if localIndex != comp.scopes.DeclareAnonymous() {
+			panic("this local did not match expected position")
+		}
+	}
+
+	if err := comp.compileBlockNode(t, block); err != nil {
+		return err
+	}
+
+	if addMissingReturn {
+		// If the function did not end with a return statement, we need to add an OpRet for safety.
+		lastStmt := block.Stmts[len(block.Stmts)-1]
+		// TODO: Get rid of EmptyStmt so we can use the Kind field to determine if the last statement is a return statement.
+		if lastStmt.Kind != ast.StmtKindReturn {
+			t.AppendOp(code.OpPushNull)
+			t.AppendOp(code.OpRet)
+		}
+	}
+
+	comp.funcs = append(comp.funcs, t)
+
+	return nil
+}
+
 func (comp *Compiler) exitScopeAndCleanStack(t *code.Builder) {
 	if stackSpace := comp.scopes.Exit(); stackSpace != 0 {
 		t.AppendOp(code.OpDrop)