about summary refs log tree commit diff
path: root/pkg/lang
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/lang')
-rw-r--r--pkg/lang/compiler/compiler.go9
-rw-r--r--pkg/lang/compiler/compiler_test.go1
-rw-r--r--pkg/lang/compiler/scope_chain.go3
-rw-r--r--pkg/lang/compiler/symbol.go1
-rw-r--r--pkg/lang/vm/code/op.go2
-rw-r--r--pkg/lang/vm/exec.go31
-rw-r--r--pkg/lang/vm/text/decompiler.go1
-rw-r--r--pkg/lang/vm/text/op.go1
-rw-r--r--pkg/lang/vm/value/data.go4
-rw-r--r--pkg/lang/vm/vm.go6
-rw-r--r--pkg/lang/vm/vm_test.go8
11 files changed, 54 insertions, 13 deletions
diff --git a/pkg/lang/compiler/compiler.go b/pkg/lang/compiler/compiler.go
index ae2a71c..728dec1 100644
--- a/pkg/lang/compiler/compiler.go
+++ b/pkg/lang/compiler/compiler.go
@@ -50,7 +50,7 @@ func (comp *Compiler) Compile() (code.Code, error) {
 }
 
 func (comp *Compiler) preDeclareFunction(fnDeclStmt ast.StmtFnDecl) error {
-	if _, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value); !ok {
+	if _, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value, uint(len(fnDeclStmt.Args))); !ok {
 		return fmt.Errorf("function %s already declared", fnDeclStmt.Name.Value)
 	}
 
@@ -103,7 +103,7 @@ func (comp *Compiler) compileStmt(t *code.Builder, stmt ast.Stmt) error {
 }
 
 func (comp *Compiler) compileFnDeclStmt(t *code.Builder, fnDeclStmt ast.StmtFnDecl) error {
-	marker, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value)
+	marker, ok := comp.scopes.DeclareFunction(fnDeclStmt.Name.Value, uint(len(fnDeclStmt.Args)))
 	if !ok {
 		// If we are in the root scope, the function was simply predeclared :)
 		if comp.scopes.IsRootScope() {
@@ -598,6 +598,11 @@ func (comp *Compiler) compileIdentExpr(t *code.Builder, expr ast.ExprIdent) erro
 
 		t.AppendOp(code.OpPushFunction)
 		t.AppendMarkerReference(symbol.data.marker)
+
+		if symbol.data.args != 0 {
+			t.AppendOp(code.OpSetArgCount)
+			t.AppendInt(int64(symbol.data.args))
+		}
 	default:
 		panic(fmt.Errorf("unknown symbol kind: %v", symbolId.symbolKind))
 	}
diff --git a/pkg/lang/compiler/compiler_test.go b/pkg/lang/compiler/compiler_test.go
index cd62088..bdbc375 100644
--- a/pkg/lang/compiler/compiler_test.go
+++ b/pkg/lang/compiler/compiler_test.go
@@ -383,6 +383,7 @@ func TestFunctionArgs(t *testing.T) {
 
 	expected := `
 	push_function @add
+	set_arg_count 2
 	push_int 4
 	push_int 5
 	call 2
diff --git a/pkg/lang/compiler/scope_chain.go b/pkg/lang/compiler/scope_chain.go
index 3a3b819..6705852 100644
--- a/pkg/lang/compiler/scope_chain.go
+++ b/pkg/lang/compiler/scope_chain.go
@@ -91,7 +91,7 @@ func (sc *ScopeChain) Declare(name string) (int, bool) {
 	return indexInScope, true
 }
 
-func (sc *ScopeChain) DeclareFunction(name string) (code.Marker, bool) {
+func (sc *ScopeChain) DeclareFunction(name string, args uint) (code.Marker, bool) {
 	if _, ok := sc.nameToSymbol[name]; ok {
 		return "", false
 	}
@@ -111,6 +111,7 @@ func (sc *ScopeChain) DeclareFunction(name string) (code.Marker, bool) {
 		name: name,
 		data: SymbolFunction{
 			marker: unitName,
+			args:   args,
 		},
 	})
 
diff --git a/pkg/lang/compiler/symbol.go b/pkg/lang/compiler/symbol.go
index d22cdc0..453cb2b 100644
--- a/pkg/lang/compiler/symbol.go
+++ b/pkg/lang/compiler/symbol.go
@@ -35,4 +35,5 @@ type SymbolVariable struct {
 
 type SymbolFunction struct {
 	marker code.Marker
+	args   uint
 }
diff --git a/pkg/lang/vm/code/op.go b/pkg/lang/vm/code/op.go
index 665742d..1ce08f7 100644
--- a/pkg/lang/vm/code/op.go
+++ b/pkg/lang/vm/code/op.go
@@ -30,6 +30,8 @@ const (
 	OpAddToEnv
 
 	OpAnchorType
+	
+	OpSetArgCount
 
 	OpAdd
 	OpSub
diff --git a/pkg/lang/vm/exec.go b/pkg/lang/vm/exec.go
index 32b1013..f92e486 100644
--- a/pkg/lang/vm/exec.go
+++ b/pkg/lang/vm/exec.go
@@ -433,6 +433,28 @@ func (vm *VM) execAnchorType() error {
 	return nil
 }
 
+func (vm *VM) execSetArgCount(argCount uint) error {
+	f, err := vm.stack.Pop()
+	if err != nil {
+		return err
+	}
+
+	if f.Type() != value.FunctionType {
+		return ErrInvalidOperandType{
+			Op: code.OpSetArgCount,
+			X:  f.Type(),
+		}
+	}
+
+	fn := f.Data().(value.FunctionData)
+
+	fn = fn.WithArgs(argCount)
+	f = f.WithData(fn)
+
+	vm.stack.Push(f)
+	return nil
+}
+
 type binaryOperation = func(value.Value, value.Value) (value.Value, error)
 type typesToBinaryOperation = map[value.TypeKind]map[value.TypeKind]binaryOperation
 
@@ -747,11 +769,10 @@ func (vm *VM) execCall(argCount uint) error {
 	fn := f.Data().(value.FunctionData)
 
 	if argCount != fn.Args() {
-		// TODO: Uncomment when push_function can set fn.Args()
-		// return ErrWrongNumberOfArguments{
-		// 	Got:    argCount,
-		// 	Needed: fn.Args(),
-		// }
+		return ErrWrongNumberOfArguments{
+			Got:    argCount,
+			Needed: fn.Args(),
+		}
 	}
 
 	if err = vm.stack.PushCall(fn.Pc(), vm.pc, fn.Env()); err != nil {
diff --git a/pkg/lang/vm/text/decompiler.go b/pkg/lang/vm/text/decompiler.go
index bef066b..aca024c 100644
--- a/pkg/lang/vm/text/decompiler.go
+++ b/pkg/lang/vm/text/decompiler.go
@@ -75,6 +75,7 @@ func (d *Decompiler) decompileInstruction(bc code.Raw) (string, code.Raw) {
 		code.OpGetEnv,
 		code.OpSetEnv,
 		code.OpAddToEnv,
+		code.OpSetArgCount,
 		code.OpCall:
 		i, rest := d.decompileInt(bc[1:])
 		return fmt.Sprintf("%s %s", opString, i), rest
diff --git a/pkg/lang/vm/text/op.go b/pkg/lang/vm/text/op.go
index d438a25..ce53f50 100644
--- a/pkg/lang/vm/text/op.go
+++ b/pkg/lang/vm/text/op.go
@@ -26,6 +26,7 @@ var (
 		code.OpSetEnv:       "set_env",
 		code.OpAddToEnv:     "add_to_env",
 		code.OpAnchorType:   "anchor_type",
+		code.OpSetArgCount:  "set_arg_count",
 		code.OpAdd:          "add",
 		code.OpSub:          "sub",
 		code.OpMul:          "mul",
diff --git a/pkg/lang/vm/value/data.go b/pkg/lang/vm/value/data.go
index e51904e..18e1b3d 100644
--- a/pkg/lang/vm/value/data.go
+++ b/pkg/lang/vm/value/data.go
@@ -155,6 +155,10 @@ func (f FunctionData) WithEnv(env mem.Ptr) FunctionData {
 	return FunctionData{pc: f.pc, args: f.args, env: env, native: f.native}
 }
 
+func (f FunctionData) WithArgs(args uint) FunctionData {
+	return FunctionData{pc: f.pc, args: args, env: f.env, native: f.native}
+}
+
 func (f FunctionData) String(_ mem.Mem) (string, error) {
 	if f.native != nil {
 		return fmt.Sprintf("<fn(%d) native>", f.args), nil
diff --git a/pkg/lang/vm/vm.go b/pkg/lang/vm/vm.go
index 54fd8ee..ff9c28e 100644
--- a/pkg/lang/vm/vm.go
+++ b/pkg/lang/vm/vm.go
@@ -160,6 +160,12 @@ func (vm *VM) step(op code.Op) (stepDecision, error) {
 	case code.OpAnchorType:
 		err = vm.execAnchorType()
 
+	case code.OpSetArgCount:
+		argCount, advance := vm.code.GetUint(vm.pc)
+		vm.pc += advance
+
+		err = vm.execSetArgCount(uint(argCount))
+
 	case code.OpAdd:
 		err = vm.execAdd()
 	case code.OpSub:
diff --git a/pkg/lang/vm/vm_test.go b/pkg/lang/vm/vm_test.go
index fa43b02..7fe07dd 100644
--- a/pkg/lang/vm/vm_test.go
+++ b/pkg/lang/vm/vm_test.go
@@ -91,18 +91,15 @@ func TestFibonacci(t *testing.T) {
 }
 
 func TestFunction(t *testing.T) {
-	t.Skip("Reimplement arguments")
-
 	src := `
-	push_int 44
 	push_function @subtract_two
+	set_arg_count 1
+	push_int 44
 	call 1
 	halt
 
 	@subtract_two:
-		shift 1
 		push_int 2
-		get_local 0
 		sub
 		ret
 	`
@@ -255,6 +252,7 @@ func TestTypeConstruct(t *testing.T) {
 
 	push_string "$init"
 	push_function @Cat:$init
+	set_arg_count 2
 	
 	call 2