Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/dev/13698.other.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add optional low-variance ("hat") regularization to :func:`mne.stats.f_oneway` via new ``sigma`` and ``method`` parameters, by `Aniket Singh Yadav`_.
17 changes: 16 additions & 1 deletion mne/stats/parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def ttest_ind_no_p(a, b, equal_var=True, sigma=0.0):
return t


def f_oneway(*args):
def f_oneway(*args, sigma=0.0, method="absolute"):
"""Perform a 1-way ANOVA.

The one-way ANOVA tests the null hypothesis that 2 or more groups have
Expand All @@ -125,6 +125,13 @@ def f_oneway(*args):
----------
*args : array_like
The sample measurements should be given as arguments.
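sigma : float
Non-negative regularization parameter applied to the within-group
variance (``MS_within``) to mitigate inflation of the F-statistic under
low-variance conditions. The default (``0.0``) applies no regularization;
a negative value raises a ``ValueError``.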
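method : {'absolute', 'relative'}
Strategy used to regularize the within-group variance when
``sigma > 0`` (ignored when ``sigma == 0``). Defaults to ``'absolute'``:

- ``'absolute'``: ``sigma`` is added to the variance, i.e.
``MS_within + sigma``.
- ``'relative'``: the variance is scaled proportionally, i.e.
``MS_within * (1 + sigma)``.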

Returns
-------
Expand All @@ -151,6 +158,9 @@ def f_oneway(*args):
----------
.. footbibliography::
"""
_check_option("method", method, ["absolute", "relative"])
if sigma < 0:
raise ValueError(f"sigma must be >= 0, got {sigma}")
n_classes = len(args)
n_samples_per_class = np.array([len(a) for a in args])
n_samples = np.sum(n_samples_per_class)
Expand All @@ -168,6 +178,11 @@ def f_oneway(*args):
dfwn = n_samples - n_classes
msb = ssbn / float(dfbn)
msw = sswn / float(dfwn)
if sigma > 0:
if method == "absolute":
msw = msw + sigma
else:
msw = msw * (1.0 + sigma)
f = msb / msw
return f

Expand Down
50 changes: 49 additions & 1 deletion mne/stats/tests/test_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_less

import mne
from mne.stats.parametric import _map_effects, f_mway_rm, f_threshold_mway_rm
from mne.stats.parametric import _map_effects, f_mway_rm, f_oneway, f_threshold_mway_rm

# hardcoded external test results, manually transferred
test_external = {
Expand Down Expand Up @@ -175,3 +175,51 @@ def theirs(*a, **kw):
# something to the divisor (var)
assert_allclose(got, want, rtol=2e-1, atol=1e-2)
assert_array_less(np.abs(got), np.abs(want))


@pytest.mark.parametrize("sigma", [0.0, 1e-3])
@pytest.mark.parametrize("method", ["absolute", "relative"])
@pytest.mark.parametrize("seed", [0, 42, 1337])
def test_f_oneway_hat(sigma, method, seed):
    """Check sigma=0 agrees with SciPy and that sigma>0 lowers positive F."""
    random_state = np.random.RandomState(seed)
    group_a = random_state.randn(10, 50)
    group_b = random_state.randn(10, 50)

    # Without regularization the result must match SciPy for either method.
    baseline = f_oneway(group_a, group_b, sigma=0.0, method=method)
    reference = scipy.stats.f_oneway(group_a, group_b)[0]
    assert_allclose(baseline, reference, rtol=1e-7, atol=1e-6)

    # Nothing further to compare when no regularization was requested.
    if not sigma > 0:
        return

    regularized = f_oneway(group_a, group_b, sigma=sigma, method=method)
    plain = f_oneway(group_a, group_b, sigma=0.0)
    # Only compare entries whose unregularized F is strictly positive.
    positive = plain > 0
    assert_array_less(regularized[positive], plain[positive] + 1e-10)


def test_f_oneway_hat_small_variance():
    """Check that regularization lowers the median F for near-zero variance."""
    random_state = np.random.RandomState(0)
    shape = (10, 100)
    # Two groups with tiny spread (std 1e-6) centred on 0 and 1 respectively.
    around_zero = random_state.normal(0, 1e-6, shape)
    around_one = random_state.normal(1, 1e-6, shape)

    plain = f_oneway(around_zero, around_one, sigma=0.0)
    absolute = f_oneway(around_zero, around_one, sigma=1e-3, method="absolute")
    relative = f_oneway(around_zero, around_one, sigma=1e-3, method="relative")

    median_plain = np.median(plain)
    # The unregularized statistic is expected to be very large here.
    assert median_plain > 1e6
    assert np.median(absolute) < median_plain
    assert np.median(relative) < median_plain


def test_f_oneway_hat_input_validation():
    """Check that invalid ``sigma`` and ``method`` values raise ValueError."""
    random_state = np.random.RandomState(0)
    groups = (random_state.randn(5, 10), random_state.randn(5, 10))

    # Each entry pairs the expected error-message pattern with the bad kwargs.
    invalid_calls = [
        ("sigma must be >= 0", {"sigma": -0.1}),
        ("method", {"sigma": 1e-3, "method": "invalid"}),
    ]
    for pattern, kwargs in invalid_calls:
        with pytest.raises(ValueError, match=pattern):
            f_oneway(*groups, **kwargs)