Skip to content

Commit 220e337

Browse files
Copilot and xadupre committed
Fix multiple bugs in plot_template_data.py
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
1 parent d02ce9a commit 220e337

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

_doc/examples/ml/plot_template_data.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,9 @@ def select_variables_and_clean(df):
5353
assert set(keys) & set(columns) == set(
5454
keys
5555
), f"Missing columns {set(keys) - set(keys) & set(columns)} in {sorted(df.columns)}"
56-
groups = df[[*keys, cible]].groupby(keys).count()
57-
filtered = groups[groups[cible] > 1].reset_index(drop=False)
58-
59-
mask = filtered.duplicated(subset=keys, keep=False)
60-
return filtered[~mask][[*keys, cible]], cible
56+
subset = df[[*keys, cible]]
57+
mask = subset.duplicated(subset=keys, keep=False)
58+
return subset[~mask].reset_index(drop=True), cible
6159

6260

6361
def compute_oracle(table, cible):
@@ -72,13 +70,13 @@ def compute_oracle(table, cible):
7270
columns="Session",
7371
values=cible,
7472
)
75-
# .dropna(axis=0) # fails
73+
.dropna(axis=0)
7674
.sort_index()
7775
)
7876
return mean_absolute_error(piv[2025], piv[2024])
7977

8078

81-
def split_train_test(table, cuble):
79+
def split_train_test(table, cible):
8280
X, y = table.drop(cible, axis=1), table[cible]
8381

8482
train_test = X["Session"] < 2025
@@ -87,13 +85,13 @@ def split_train_test(table, cuble):
8785

8886
train_X = X[train_test].drop(drop, axis=1)
8987
train_y = y[train_test]
90-
test_X = X[train_test].drop(drop, axis=1)
91-
test_y = y[train_test]
88+
test_X = X[~train_test].drop(drop, axis=1)
89+
test_y = y[~train_test]
9290
return train_X, test_X, train_y, test_y
9391

9492

9593
def make_pipeline(table, cible):
96-
vars = [c for c in table.columns if c != "cible"]
94+
vars = [c for c in table.columns if c != cible]
9795
num_cols = ["Capacité de l’établissement par formation"]
9896
cat_cols = [c for c in vars if c not in num_cols]
9997

0 commit comments

Comments
 (0)