Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 21 additions & 12 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,29 @@ def classes(self) -> np.ndarray:

def construct_head(self) -> nn.Sequential:
"""Constructs a simple classifier head."""
modules: list[nn.Module] = []
if self.n_layers == 0:
return nn.Sequential(nn.Linear(self.embed_dim, self.out_dim))
modules = [
nn.Linear(self.embed_dim, self.hidden_dim),
nn.ReLU(),
]
for _ in range(self.n_layers - 1):
modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
modules.extend([nn.Linear(self.hidden_dim, self.out_dim)])

for module in modules:
if isinstance(module, nn.Linear):
nn.init.kaiming_uniform_(module.weight)
modules.append(nn.Linear(self.embed_dim, self.out_dim))
else:
# If we have a hidden layer, we should first project to hidden_dim
modules = [
nn.Linear(self.embed_dim, self.hidden_dim),
nn.ReLU(),
]
for _ in range(self.n_layers - 1):
modules.extend([nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU()])
# We always have a layer mapping from hidden to out.
modules.append(nn.Linear(self.hidden_dim, self.out_dim))

linear_modules = [module for module in modules if isinstance(module, nn.Linear)]
if linear_modules:
*initial, last = linear_modules
for module in initial:
nn.init.kaiming_uniform_(module.weight, nonlinearity="relu")
nn.init.zeros_(module.bias)
# Final layer does not kaiming
nn.init.xavier_uniform_(last.weight)
nn.init.zeros_(last.bias)

return nn.Sequential(*modules)

Expand Down