diff --git a/builtin/builtin.go b/builtin/builtin.go index 87e73614..4c6b041f 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -263,14 +263,17 @@ var Builtins = []*Function{ }, { Name: "split", - Func: func(args ...any) (any, error) { + Safe: func(args ...any) (any, uint, error) { + var parts []string if len(args) == 2 { - return strings.Split(args[0].(string), args[1].(string)), nil + parts = strings.Split(args[0].(string), args[1].(string)) } else if len(args) == 3 { - return strings.SplitN(args[0].(string), args[1].(string), runtime.ToInt(args[2])), nil + parts = strings.SplitN(args[0].(string), args[1].(string), runtime.ToInt(args[2])) } else { - return nil, fmt.Errorf("invalid number of arguments for split (expected 2 or 3, got %d)", len(args)) + return nil, 0, fmt.Errorf("invalid number of arguments for split (expected 2 or 3, got %d)", len(args)) } + // Charge proportional to number of produced elements (slice growth + headers). + return parts, uint(len(parts)), nil }, Types: types( strings.Split, @@ -279,14 +282,16 @@ var Builtins = []*Function{ }, { Name: "splitAfter", - Func: func(args ...any) (any, error) { + Safe: func(args ...any) (any, uint, error) { + var parts []string if len(args) == 2 { - return strings.SplitAfter(args[0].(string), args[1].(string)), nil + parts = strings.SplitAfter(args[0].(string), args[1].(string)) } else if len(args) == 3 { - return strings.SplitAfterN(args[0].(string), args[1].(string), runtime.ToInt(args[2])), nil + parts = strings.SplitAfterN(args[0].(string), args[1].(string), runtime.ToInt(args[2])) } else { - return nil, fmt.Errorf("invalid number of arguments for splitAfter (expected 2 or 3, got %d)", len(args)) + return nil, 0, fmt.Errorf("invalid number of arguments for splitAfter (expected 2 or 3, got %d)", len(args)) } + return parts, uint(len(parts)), nil }, Types: types( strings.SplitAfter, diff --git a/expr_test.go b/expr_test.go index 1bce3c8d..94645167 100644 --- a/expr_test.go +++ b/expr_test.go @@ -3005,3 +3005,35 @@ func TestBytesLiteral_errors(t *testing.T) { }) } } + +func TestMemoryBudget_SplitBuiltin(t *testing.T) { + type Env struct { + S string `expr:"s"` + } + + in := Env{S: strings.Repeat("a", 200000)} + + program, err := expr.Compile(`split(s, "a")`, expr.Env(Env{})) + require.NoError(t, err, "compile error") + + m := vm.VM{MemoryBudget: 10} + _, err = m.Run(program, in) + require.Error(t, err, "expected memory budget error") + assert.Contains(t, err.Error(), "memory budget exceeded") +} + +func TestMemoryBudget_SplitAfterBuiltin(t *testing.T) { + type Env struct { + S string `expr:"s"` + } + + in := Env{S: strings.Repeat("a", 200000)} + + program, err := expr.Compile(`splitAfter(s, "a")`, expr.Env(Env{})) + require.NoError(t, err, "compile error") + + m := vm.VM{MemoryBudget: 10} + _, err = m.Run(program, in) + require.Error(t, err, "expected memory budget error") + assert.Contains(t, err.Error(), "memory budget exceeded") +}