Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion mapreader/classify/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
is_inception: bool = False,
load_path: str | None = None,
force_device: bool = False,
huggingface=False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a type hint here

Suggested change
huggingface=False,
huggingface: bool = False,

**kwargs,
):
# set up device
Expand Down Expand Up @@ -141,6 +142,7 @@ def __init__(
)

self.labels_map = labels_map
self.huggingface = huggingface
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would skip setting self.huggingface since its only referenced further down in the init and instead just use the huggingface value in the if statement below


# set up model and move to device
print("[INFO] Initializing model.")
Expand All @@ -149,7 +151,25 @@ def __init__(
self.input_size = input_size
self.is_inception = is_inception
elif isinstance(model, str):
self._initialize_model(model, **kwargs)
if self.huggingface:
try:
from transformers import AutoModelForImageClassification, AutoImageProcessor
except ImportError:
raise ImportError(
"Hugging Face models require the 'transformers' library: 'pip install transformers'."
)
print(f"[INFO] Initializing Hugging Face model: {model}")
num_labels = len(self.labels_map)
self.model = AutoModelForImageClassification.from_pretrained(
model,
num_labels=num_labels,
ignore_mismatched_sizes=True
).to(self.device)
self.hf_processor = AutoImageProcessor.from_pretrained(model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could also be just hf_processor instead of an attribute self.hf_processor since it isn't used outside of this function

self.input_size = getattr(self.hf_processor, "size", {"height": 224})["height"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking here there seem to be 3 options for how size is defined.

Could you implement these, i.e.:

Suggested change
self.input_size = getattr(self.hf_processor, "size", {"height": 224})["height"]
size = getattr(hf_processor, "size", {})
if "height" in size and "width" in size:
self.input_size = (size["height"], size["width"])
elif "shortest_edge" in size:
self.input_size = (size["shortest_edge"], size["shortest_edge"])
else:
self.input_size = input_size

self.is_inception = False
else:
self._initialize_model(model, **kwargs)

self.optimizer = None
self.scheduler = None
Expand Down
Loading