diff --git a/doc/changes/dev/13698.other.rst b/doc/changes/dev/13698.other.rst new file mode 100644 index 00000000000..fab48e560d1 --- /dev/null +++ b/doc/changes/dev/13698.other.rst @@ -0,0 +1 @@ +Add optional low-variance ("hat") regularization to :func:`mne.stats.f_oneway` via new ``sigma`` and ``method`` parameters, by `Aniket Singh Yadav`_. \ No newline at end of file diff --git a/mne/stats/parametric.py b/mne/stats/parametric.py index 2cc0bff2ea1..ad60bd70c6b 100644 --- a/mne/stats/parametric.py +++ b/mne/stats/parametric.py @@ -111,7 +111,7 @@ def ttest_ind_no_p(a, b, equal_var=True, sigma=0.0): return t -def f_oneway(*args): +def f_oneway(*args, sigma=0.0, method="absolute"): """Perform a 1-way ANOVA. The one-way ANOVA tests the null hypothesis that 2 or more groups have @@ -125,6 +125,13 @@ def f_oneway(*args): ---------- *args : array_like The sample measurements should be given as arguments. + sigma : float + Regularization parameter applied to the within-group variance + (``MS_within``) to mitigate inflation of the F-statistic under + low-variance conditions. + method : {'absolute', 'relative'} + Strategy used to regularize the within-group variance when + ``sigma > 0``: Returns ------- @@ -151,6 +158,9 @@ def f_oneway(*args): ---------- .. footbibliography:: """ + _check_option("method", method, ["absolute", "relative"]) + if sigma < 0: + raise ValueError(f"sigma must be >= 0, got {sigma}") n_classes = len(args) n_samples_per_class = np.array([len(a) for a in args]) n_samples = np.sum(n_samples_per_class) @@ -168,6 +178,11 @@ def f_oneway(*args): dfwn = n_samples - n_classes msb = ssbn / float(dfbn) msw = sswn / float(dfwn) + if sigma > 0: + if method == "absolute": + msw = msw + sigma + else: + msw = msw * (1.0 + sigma) f = msb / msw return f diff --git a/mne/stats/tests/test_parametric.py b/mne/stats/tests/test_parametric.py index 61ecbc43af3..c4dc70e570e 100644 --- a/mne/stats/tests/test_parametric.py +++ b/mne/stats/tests/test_parametric.py @@ -11,7 +11,7 @@ from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_less import mne -from mne.stats.parametric import _map_effects, f_mway_rm, f_threshold_mway_rm +from mne.stats.parametric import _map_effects, f_mway_rm, f_oneway, f_threshold_mway_rm # hardcoded external test results, manually transferred test_external = { @@ -175,3 +175,51 @@ def theirs(*a, **kw): # something to the divisor (var) assert_allclose(got, want, rtol=2e-1, atol=1e-2) assert_array_less(np.abs(got), np.abs(want)) + + +@pytest.mark.parametrize("sigma", [0.0, 1e-3]) +@pytest.mark.parametrize("method", ["absolute", "relative"]) +@pytest.mark.parametrize("seed", [0, 42, 1337]) +def test_f_oneway_hat(sigma, method, seed): + """Test f_oneway hat (low-variance) regularization.""" + rng = np.random.RandomState(seed) + X1 = rng.randn(10, 50) + X2 = rng.randn(10, 50) + + f_ours = f_oneway(X1, X2, sigma=0.0, method=method) + f_scipy = scipy.stats.f_oneway(X1, X2)[0] + assert_allclose(f_ours, f_scipy, rtol=1e-7, atol=1e-6) + + if sigma > 0: + f_reg = f_oneway(X1, X2, sigma=sigma, method=method) + f_unreg = f_oneway(X1, X2, sigma=0.0) + pos = f_unreg > 0 + assert_array_less(f_reg[pos], f_unreg[pos] + 1e-10) + + +def test_f_oneway_hat_small_variance(): + """Test that f_oneway hat stabilizes F-values for near-zero variance.""" + rng = np.random.RandomState(0) + X1 = rng.normal(0, 1e-6, (10, 100)) + X2 = rng.normal(1, 1e-6, (10, 100)) + + f_unreg = f_oneway(X1, X2, sigma=0.0) + f_abs = f_oneway(X1, X2, sigma=1e-3, method="absolute") + f_rel = f_oneway(X1, X2, sigma=1e-3, method="relative") + + assert np.median(f_unreg) > 1e6 + assert np.median(f_abs) < np.median(f_unreg) + assert np.median(f_rel) < np.median(f_unreg) + + +def test_f_oneway_hat_input_validation(): + """Test f_oneway input validation for sigma and method.""" + rng = np.random.RandomState(0) + X1 = rng.randn(5, 10) + X2 = rng.randn(5, 10) + + with pytest.raises(ValueError, match="sigma must be >= 0"): + f_oneway(X1, X2, sigma=-0.1) + + with pytest.raises(ValueError, match="method"): + f_oneway(X1, X2, sigma=1e-3, method="invalid")