about summary refs log tree commit diff
path: root/pkg/lang/vm
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/lang/vm')
-rw-r--r--pkg/lang/vm/code/op.go2
-rw-r--r--pkg/lang/vm/exec.go31
-rw-r--r--pkg/lang/vm/text/decompiler.go1
-rw-r--r--pkg/lang/vm/text/op.go1
-rw-r--r--pkg/lang/vm/value/data.go4
-rw-r--r--pkg/lang/vm/vm.go6
-rw-r--r--pkg/lang/vm/vm_test.go8
7 files changed, 43 insertions, 10 deletions
diff --git a/pkg/lang/vm/code/op.go b/pkg/lang/vm/code/op.go
index 665742d..1ce08f7 100644
--- a/pkg/lang/vm/code/op.go
+++ b/pkg/lang/vm/code/op.go
@@ -30,6 +30,8 @@ const (
 	OpAddToEnv
 
 	OpAnchorType
+	
+	OpSetArgCount
 
 	OpAdd
 	OpSub
diff --git a/pkg/lang/vm/exec.go b/pkg/lang/vm/exec.go
index 32b1013..f92e486 100644
--- a/pkg/lang/vm/exec.go
+++ b/pkg/lang/vm/exec.go
@@ -433,6 +433,28 @@ func (vm *VM) execAnchorType() error {
 	return nil
 }
 
+func (vm *VM) execSetArgCount(argCount uint) error {
+	f, err := vm.stack.Pop()
+	if err != nil {
+		return err
+	}
+
+	if f.Type() != value.FunctionType {
+		return ErrInvalidOperandType{
+			Op: code.OpSetArgCount,
+			X:  f.Type(),
+		}
+	}
+
+	fn := f.Data().(value.FunctionData)
+
+	fn = fn.WithArgs(argCount)
+	f = f.WithData(fn)
+
+	vm.stack.Push(f)
+	return nil
+}
+
 type binaryOperation = func(value.Value, value.Value) (value.Value, error)
 type typesToBinaryOperation = map[value.TypeKind]map[value.TypeKind]binaryOperation
 
@@ -747,11 +769,10 @@ func (vm *VM) execCall(argCount uint) error {
 	fn := f.Data().(value.FunctionData)
 
 	if argCount != fn.Args() {
-		// TODO: Uncomment when push_function can set fn.Args()
-		// return ErrWrongNumberOfArguments{
-		// 	Got:    argCount,
-		// 	Needed: fn.Args(),
-		// }
+		return ErrWrongNumberOfArguments{
+			Got:    argCount,
+			Needed: fn.Args(),
+		}
 	}
 
 	if err = vm.stack.PushCall(fn.Pc(), vm.pc, fn.Env()); err != nil {
diff --git a/pkg/lang/vm/text/decompiler.go b/pkg/lang/vm/text/decompiler.go
index bef066b..aca024c 100644
--- a/pkg/lang/vm/text/decompiler.go
+++ b/pkg/lang/vm/text/decompiler.go
@@ -75,6 +75,7 @@ func (d *Decompiler) decompileInstruction(bc code.Raw) (string, code.Raw) {
 		code.OpGetEnv,
 		code.OpSetEnv,
 		code.OpAddToEnv,
+		code.OpSetArgCount,
 		code.OpCall:
 		i, rest := d.decompileInt(bc[1:])
 		return fmt.Sprintf("%s %s", opString, i), rest
diff --git a/pkg/lang/vm/text/op.go b/pkg/lang/vm/text/op.go
index d438a25..ce53f50 100644
--- a/pkg/lang/vm/text/op.go
+++ b/pkg/lang/vm/text/op.go
@@ -26,6 +26,7 @@ var (
 		code.OpSetEnv:       "set_env",
 		code.OpAddToEnv:     "add_to_env",
 		code.OpAnchorType:   "anchor_type",
+		code.OpSetArgCount:  "set_arg_count",
 		code.OpAdd:          "add",
 		code.OpSub:          "sub",
 		code.OpMul:          "mul",
diff --git a/pkg/lang/vm/value/data.go b/pkg/lang/vm/value/data.go
index e51904e..18e1b3d 100644
--- a/pkg/lang/vm/value/data.go
+++ b/pkg/lang/vm/value/data.go
@@ -155,6 +155,10 @@ func (f FunctionData) WithEnv(env mem.Ptr) FunctionData {
 	return FunctionData{pc: f.pc, args: f.args, env: env, native: f.native}
 }
 
+func (f FunctionData) WithArgs(args uint) FunctionData {
+	return FunctionData{pc: f.pc, args: args, env: f.env, native: f.native}
+}
+
 func (f FunctionData) String(_ mem.Mem) (string, error) {
 	if f.native != nil {
 		return fmt.Sprintf("<fn(%d) native>", f.args), nil
diff --git a/pkg/lang/vm/vm.go b/pkg/lang/vm/vm.go
index 54fd8ee..ff9c28e 100644
--- a/pkg/lang/vm/vm.go
+++ b/pkg/lang/vm/vm.go
@@ -160,6 +160,12 @@ func (vm *VM) step(op code.Op) (stepDecision, error) {
 	case code.OpAnchorType:
 		err = vm.execAnchorType()
 
+	case code.OpSetArgCount:
+		argCount, advance := vm.code.GetUint(vm.pc)
+		vm.pc += advance
+
+		err = vm.execSetArgCount(uint(argCount))
+
 	case code.OpAdd:
 		err = vm.execAdd()
 	case code.OpSub:
diff --git a/pkg/lang/vm/vm_test.go b/pkg/lang/vm/vm_test.go
index fa43b02..7fe07dd 100644
--- a/pkg/lang/vm/vm_test.go
+++ b/pkg/lang/vm/vm_test.go
@@ -91,18 +91,15 @@ func TestFibonacci(t *testing.T) {
 }
 
 func TestFunction(t *testing.T) {
-	t.Skip("Reimplement arguments")
-
 	src := `
-	push_int 44
 	push_function @subtract_two
+	set_arg_count 1
+	push_int 44
 	call 1
 	halt
 
 	@subtract_two:
-		shift 1
 		push_int 2
-		get_local 0
 		sub
 		ret
 	`
@@ -255,6 +252,7 @@ func TestTypeConstruct(t *testing.T) {
 
 	push_string "$init"
 	push_function @Cat:$init
+	set_arg_count 2
 	
 	call 2