diff --git a/src/crawlee/_utils/http.py b/src/crawlee/_utils/http.py new file mode 100644 index 0000000000..be7a1fa5e4 --- /dev/null +++ b/src/crawlee/_utils/http.py @@ -0,0 +1,41 @@ +"""HTTP utility functions for Crawlee.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + + +def parse_retry_after_header(value: str | None) -> timedelta | None: + """Parse the Retry-After HTTP header value. + + The header can contain either a number of seconds or an HTTP-date. + See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + + Args: + value: The raw Retry-After header value. + + Returns: + A timedelta representing the delay, or None if the header is missing or unparsable. + """ + if not value: + return None + + # Try parsing as integer seconds first. + try: + seconds = int(value) + return timedelta(seconds=seconds) + except ValueError: + pass + + # Try parsing as HTTP-date (e.g., "Wed, 21 Oct 2015 07:28:00 GMT"). + from email.utils import parsedate_to_datetime # noqa: PLC0415 + + try: + retry_date = parsedate_to_datetime(value) + delay = retry_date - datetime.now(retry_date.tzinfo or timezone.utc) + if delay.total_seconds() > 0: + return delay + except (ValueError, TypeError): + pass + + return None diff --git a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py index 7aafa49e2e..13bf2bd1b1 100644 --- a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +++ b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py @@ -279,7 +279,12 @@ async def _handle_status_code_response( """ status_code = context.http_response.status_code if self._retry_on_blocked: - self._raise_for_session_blocked_status_code(context.session, status_code) + self._raise_for_session_blocked_status_code( + context.session, + status_code, + request_url=context.request.url, + retry_after_header=context.http_response.headers.get('retry-after'), + ) 
self._raise_for_error_status_code(status_code) yield context diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 6451d59461..b514e1e5a5 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -45,6 +45,7 @@ ) from crawlee._utils.docs import docs_group from crawlee._utils.file import atomic_write, export_csv_to_stream, export_json_to_stream +from crawlee._utils.http import parse_retry_after_header from crawlee._utils.recurring_task import RecurringTask from crawlee._utils.robots import RobotsTxtFile from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute @@ -63,6 +64,7 @@ ) from crawlee.events._types import Event, EventCrawlerStatusData from crawlee.http_clients import ImpitHttpClient +from crawlee.request_loaders import ThrottlingRequestManager from crawlee.router import Router from crawlee.sessions import SessionPool from crawlee.statistics import Statistics, StatisticsState @@ -611,12 +613,17 @@ async def _get_proxy_info(self, request: Request, session: Session | None) -> Pr ) async def get_request_manager(self) -> RequestManager: - """Return the configured request manager. If none is configured, open and return the default request queue.""" + """Return the configured request manager. If none is configured, open and return the default request queue. + + The returned manager is wrapped with `ThrottlingRequestManager` to enforce + per-domain delays from 429 responses and robots.txt crawl-delay directives. 
+ """ if not self._request_manager: - self._request_manager = await RequestQueue.open( + inner = await RequestQueue.open( storage_client=self._service_locator.get_storage_client(), configuration=self._service_locator.get_configuration(), ) + self._request_manager = ThrottlingRequestManager(inner) return self._request_manager @@ -707,12 +714,21 @@ async def run( await self._session_pool.reset_store() request_manager = await self.get_request_manager() - if purge_request_queue and isinstance(request_manager, RequestQueue): - await request_manager.drop() - self._request_manager = await RequestQueue.open( - storage_client=self._service_locator.get_storage_client(), - configuration=self._service_locator.get_configuration(), - ) + if purge_request_queue: + if isinstance(request_manager, RequestQueue): + await request_manager.drop() + self._request_manager = await RequestQueue.open( + storage_client=self._service_locator.get_storage_client(), + configuration=self._service_locator.get_configuration(), + ) + elif isinstance(request_manager, ThrottlingRequestManager): + await request_manager.drop() + inner = await RequestQueue.open( + storage_client=self._service_locator.get_storage_client(), + configuration=self._service_locator.get_configuration(), + ) + self._throttling_manager = ThrottlingRequestManager(inner) + self._request_manager = self._throttling_manager if requests is not None: await self.add_requests(requests) @@ -1442,6 +1458,10 @@ async def __run_task_function(self) -> None: await self._mark_request_as_handled(request) + # Record successful request to reset rate limit backoff for this domain. 
+ if isinstance(request_manager, ThrottlingRequestManager): + request_manager.record_success(request.url) + if session and session.is_usable: session.mark_good() @@ -1542,22 +1562,43 @@ def _raise_for_error_status_code(self, status_code: int) -> None: if is_status_code_server_error(status_code) and not is_ignored_status: raise HttpStatusCodeError('Error status code returned', status_code) - def _raise_for_session_blocked_status_code(self, session: Session | None, status_code: int) -> None: + def _raise_for_session_blocked_status_code( + self, + session: Session | None, + status_code: int, + *, + request_url: str = '', + retry_after_header: str | None = None, + ) -> None: """Raise an exception if the given status code indicates the session is blocked. + If the status code is 429 (Too Many Requests), the domain is recorded as + rate-limited in the `ThrottlingRequestManager` for per-domain backoff. + Args: session: The session used for the request. If None, no check is performed. status_code: The HTTP status code to check. + request_url: The request URL, used for per-domain rate limit tracking. + retry_after_header: The value of the Retry-After response header, if present. Raises: SessionError: If the status code indicates the session is blocked. """ + if status_code == 429 and request_url: # noqa: PLR2004 + retry_after = parse_retry_after_header(retry_after_header) + + # _request_manager might not be initialized yet if called directly or early, + # but usually it's set in get_request_manager(). 
+ if isinstance(self._request_manager, ThrottlingRequestManager): + self._request_manager.record_domain_delay(request_url, retry_after=retry_after) + if session is not None and session.is_blocked_status_code( status_code=status_code, ignore_http_error_status_codes=self._ignore_http_error_status_codes, ): raise SessionError(f'Assuming the session is blocked based on HTTP status code {status_code}') + # NOTE: _parse_retry_after_header has been moved to crawlee._utils.http.parse_retry_after_header def _check_request_collision(self, request: Request, session: Session | None) -> None: """Raise an exception if a request cannot access required resources. @@ -1582,7 +1623,16 @@ async def _is_allowed_based_on_robots_txt_file(self, url: str) -> bool: if not self._respect_robots_txt_file: return True robots_txt_file = await self._get_robots_txt_file_for_url(url) - return not robots_txt_file or robots_txt_file.is_allowed(url) + if not robots_txt_file: + return True + + # Wire robots.txt crawl-delay into ThrottlingRequestManager + if isinstance(self._request_manager, ThrottlingRequestManager): + crawl_delay = robots_txt_file.get_crawl_delay() + if crawl_delay is not None: + self._request_manager.set_crawl_delay(url, crawl_delay) + + return robots_txt_file.is_allowed(url) async def _get_robots_txt_file_for_url(self, url: str) -> RobotsTxtFile | None: """Get the RobotsTxtFile for a given URL. 
diff --git a/src/crawlee/crawlers/_playwright/_playwright_crawler.py b/src/crawlee/crawlers/_playwright/_playwright_crawler.py index 6f4b2b0e9d..9db340c976 100644 --- a/src/crawlee/crawlers/_playwright/_playwright_crawler.py +++ b/src/crawlee/crawlers/_playwright/_playwright_crawler.py @@ -459,7 +459,13 @@ async def _handle_status_code_response( """ status_code = context.response.status if self._retry_on_blocked: - self._raise_for_session_blocked_status_code(context.session, status_code) + retry_after_header = context.response.headers.get('retry-after') + self._raise_for_session_blocked_status_code( + context.session, + status_code, + request_url=context.request.url, + retry_after_header=retry_after_header, + ) self._raise_for_error_status_code(status_code) yield context diff --git a/src/crawlee/request_loaders/__init__.py b/src/crawlee/request_loaders/__init__.py index c04d9aa810..6dd8cccfab 100644 --- a/src/crawlee/request_loaders/__init__.py +++ b/src/crawlee/request_loaders/__init__.py @@ -3,5 +3,13 @@ from ._request_manager import RequestManager from ._request_manager_tandem import RequestManagerTandem from ._sitemap_request_loader import SitemapRequestLoader +from ._throttling_request_manager import ThrottlingRequestManager -__all__ = ['RequestList', 'RequestLoader', 'RequestManager', 'RequestManagerTandem', 'SitemapRequestLoader'] +__all__ = [ + 'RequestList', + 'RequestLoader', + 'RequestManager', + 'RequestManagerTandem', + 'SitemapRequestLoader', + 'ThrottlingRequestManager', +] diff --git a/src/crawlee/request_loaders/_throttling_request_manager.py b/src/crawlee/request_loaders/_throttling_request_manager.py new file mode 100644 index 0000000000..b2bc40eeb4 --- /dev/null +++ b/src/crawlee/request_loaders/_throttling_request_manager.py @@ -0,0 +1,358 @@ +"""A request manager wrapper that enforces per-domain delays. 
+ +Handles both HTTP 429 backoff and robots.txt crawl-delay at the scheduling layer, +eliminating the busy-wait problem described in https://github.com/apify/crawlee-python/issues/1437. + +Also addresses https://github.com/apify/crawlee-python/issues/1396 by providing a unified +delay mechanism for crawl-delay directives. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from logging import getLogger +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +from typing_extensions import override + +from crawlee._utils.docs import docs_group +from crawlee.request_loaders._request_manager import RequestManager +from crawlee.storages import RequestQueue + +if TYPE_CHECKING: + from collections.abc import Sequence + + from crawlee._request import Request + from crawlee.storage_clients.models import ProcessedRequest + +logger = getLogger(__name__) + + +@dataclass +class _DomainState: + """Tracks delay state for a single domain.""" + + domain: str + """The domain being tracked.""" + + throttled_until: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + """Earliest time the next request to this domain is allowed.""" + + consecutive_429_count: int = 0 + """Number of consecutive 429 responses (for exponential backoff).""" + + crawl_delay: timedelta | None = None + """Minimum interval between requests from robots.txt crawl-delay directive.""" + + last_request_at: datetime | None = None + """When the last request to this domain was dispatched.""" + + +@docs_group('Request loaders') +class ThrottlingRequestManager(RequestManager): + """A request manager that wraps another and enforces per-domain delays. + + This moves throttling logic into the scheduling layer instead of the execution + layer. When `fetch_next_request()` is called, it intelligently handles delays: + + - If the next request's domain is not throttled, it returns immediately. 
+ - If the domain is throttled but other requests are available, it moves the + throttled request to a per-domain sub-queue and tries the next one. + - If all available requests are throttled, it `asyncio.sleep()`s until the + earliest domain cooldown expires — eliminating busy-wait and unnecessary + queue writes. + + Delay sources: + - HTTP 429 responses (via `record_domain_delay`) + - robots.txt crawl-delay directives (via `set_crawl_delay`) + """ + + _BASE_DELAY = timedelta(seconds=2) + """Initial delay after the first 429 response from a domain.""" + + _MAX_DELAY = timedelta(seconds=60) + """Maximum delay between requests to a rate-limited domain.""" + + def __init__(self, inner: RequestManager) -> None: + """Initialize the throttling manager. + + Args: + inner: The underlying request manager to wrap (typically a RequestQueue). + """ + self._inner = inner + self._domain_states: dict[str, _DomainState] = {} + self._sub_queues: dict[str, RequestQueue] = {} + self._dispatched_origins: dict[str, str] = {} + self._transferred_requests_count = 0 + + @staticmethod + def _extract_domain(url: str) -> str: + """Extract the domain (hostname) from a URL.""" + parsed = urlparse(url) + return parsed.hostname or '' + + async def _get_or_create_sub_queue(self, domain: str) -> RequestQueue: + """Get or create a per-domain sub-queue.""" + if domain not in self._sub_queues: + self._sub_queues[domain] = await RequestQueue.open(alias=f'throttled-{domain}') + return self._sub_queues[domain] + + def _is_domain_throttled(self, domain: str) -> bool: + """Check if a domain is currently throttled.""" + state = self._domain_states.get(domain) + if state is None: + return False + + now = datetime.now(timezone.utc) + + # Check 429 backoff. + if now < state.throttled_until: + return True + + # Check crawl-delay: enforce minimum interval between requests. 
+ return ( + state.crawl_delay is not None + and state.last_request_at is not None + and now < state.last_request_at + state.crawl_delay + ) + + def _get_earliest_available_time(self) -> datetime: + """Get the earliest time any throttled domain becomes available.""" + now = datetime.now(timezone.utc) + earliest = now + self._MAX_DELAY # Fallback upper bound. + + for state in self._domain_states.values(): + # Consider 429 backoff. + if state.throttled_until > now and state.throttled_until < earliest: + earliest = state.throttled_until + + # Consider crawl-delay. + if state.crawl_delay is not None and state.last_request_at is not None: + next_allowed = state.last_request_at + state.crawl_delay + if next_allowed > now and next_allowed < earliest: + earliest = next_allowed + + return earliest + + def record_domain_delay(self, url: str, *, retry_after: timedelta | None = None) -> None: + """Record a 429 Too Many Requests response for the domain of the given URL. + + Increments the consecutive 429 count and calculates the next allowed + request time using exponential backoff or the Retry-After value. + + Args: + url: The URL that received a 429 response. + retry_after: Optional delay from the Retry-After header. If provided, + it takes priority over the calculated exponential backoff. + """ + domain = self._extract_domain(url) + if not domain: + return + + now = datetime.now(timezone.utc) + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + state = self._domain_states[domain] + state.consecutive_429_count += 1 + + # Calculate delay: use Retry-After if provided, otherwise exponential backoff. + delay = retry_after if retry_after is not None else self._BASE_DELAY * (2 ** (state.consecutive_429_count - 1)) + + # Cap the delay. 
+ delay = min(delay, self._MAX_DELAY) + + state.throttled_until = now + delay + + logger.info( + f'Rate limit (429) detected for domain "{domain}" ' + f'(consecutive: {state.consecutive_429_count}, delay: {delay.total_seconds():.1f}s)' + ) + + def record_success(self, url: str) -> None: + """Record a successful request, resetting the backoff state for that domain. + + Args: + url: The URL that received a successful response. + """ + domain = self._extract_domain(url) + state = self._domain_states.get(domain) + + if state is not None and state.consecutive_429_count > 0: + logger.debug(f'Resetting rate limit state for domain "{domain}" after successful request') + state.consecutive_429_count = 0 + + def set_crawl_delay(self, url: str, delay_seconds: int) -> None: + """Set the robots.txt crawl-delay for a domain. + + Args: + url: A URL from the domain to throttle. + delay_seconds: The crawl-delay value in seconds. + """ + domain = self._extract_domain(url) + if not domain: + return + + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + state = self._domain_states[domain] + state.crawl_delay = timedelta(seconds=delay_seconds) + + logger.debug(f'Set crawl-delay for domain "{domain}" to {delay_seconds}s') + + def _mark_domain_dispatched(self, url: str) -> None: + """Record that a request to this domain was just dispatched.""" + domain = self._extract_domain(url) + if domain: + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + state = self._domain_states[domain] + state.last_request_at = datetime.now(timezone.utc) + + # ────────────────────────────────────────────────────── + # RequestManager interface delegation + smart scheduling + # ────────────────────────────────────────────────────── + + @override + async def drop(self) -> None: + await self._inner.drop() + for sq in self._sub_queues.values(): + await sq.drop() + self._sub_queues.clear() + 
self._dispatched_origins.clear() + + @override + async def add_request(self, request: str | Request, *, forefront: bool = False) -> ProcessedRequest: + return await self._inner.add_request(request, forefront=forefront) + + @override + async def add_requests( + self, + requests: Sequence[str | Request], + *, + forefront: bool = False, + batch_size: int = 1000, + wait_time_between_batches: timedelta = timedelta(seconds=1), + wait_for_all_requests_to_be_added: bool = False, + wait_for_all_requests_to_be_added_timeout: timedelta | None = None, + ) -> None: + return await self._inner.add_requests( + requests, + forefront=forefront, + batch_size=batch_size, + wait_time_between_batches=wait_time_between_batches, + wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, + wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout, + ) + + @override + async def reclaim_request(self, request: Request, *, forefront: bool = False) -> ProcessedRequest | None: + origin = self._dispatched_origins.get(request.unique_key) + if origin and origin != 'inner' and origin in self._sub_queues: + return await self._sub_queues[origin].reclaim_request(request, forefront=forefront) + return await self._inner.reclaim_request(request, forefront=forefront) + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + origin = self._dispatched_origins.get(request.unique_key) + if origin and origin != 'inner' and origin in self._sub_queues: + return await self._sub_queues[origin].mark_request_as_handled(request) + return await self._inner.mark_request_as_handled(request) + + @override + async def get_handled_count(self) -> int: + count = await self._inner.get_handled_count() + for sq in self._sub_queues.values(): + count += await sq.get_handled_count() + return max(0, count - self._transferred_requests_count) + + @override + async def get_total_count(self) -> int: + count = await self._inner.get_total_count() + for sq 
in self._sub_queues.values(): + count += await sq.get_total_count() + return max(0, count - self._transferred_requests_count) + + @override + async def is_empty(self) -> bool: + if not await self._inner.is_empty(): + return False + for sq in self._sub_queues.values(): + if not await sq.is_empty(): + return False + return True + + @override + async def is_finished(self) -> bool: + if not await self._inner.is_finished(): + return False + for sq in self._sub_queues.values(): + if not await sq.is_finished(): + return False + return True + + @override + async def fetch_next_request(self) -> Request | None: + """Fetch the next request, respecting per-domain delays. + + If the next available request belongs to a throttled domain, it is moved to + a per-domain sub-queue. If all unhandled requests are in throttled sub-queues, + it sleeps until the earliest domain becomes available. + """ + # First, check if any sub-queues have unthrottled requests. + for domain, sq in self._sub_queues.items(): + if not self._is_domain_throttled(domain): + req = await sq.fetch_next_request() + if req: + self._mark_domain_dispatched(req.url) + self._dispatched_origins[req.unique_key] = domain + return req + + # Try fetching from the inner queue. + while True: + request = await self._inner.fetch_next_request() + + if request is None: + # No more requests in inner queue. Check if sub-queues have requests. + have_sq_requests = False + for sq in self._sub_queues.values(): + if not await sq.is_empty(): + have_sq_requests = True + break + + if have_sq_requests: + # Requests exist but domains are throttled. Sleep and retry. + earliest = self._get_earliest_available_time() + sleep_duration = max( + (earliest - datetime.now(timezone.utc)).total_seconds(), + 0.1, # Minimum sleep to avoid tight loops. + ) + logger.debug( + f'Throttled sub-queues have requests. ' + f'Sleeping {sleep_duration:.1f}s until earliest domain is available.' 
+ ) + await asyncio.sleep(sleep_duration) + return await self.fetch_next_request() + return None + + domain = self._extract_domain(request.url) + + if not self._is_domain_throttled(domain): + # Domain is clear — dispatch immediately. + self._mark_domain_dispatched(request.url) + self._dispatched_origins[request.unique_key] = 'inner' + return request + + # Domain is throttled — move this request to the sub-queue. + logger.debug( + f'Request to {request.url} moved to sub-queue — domain "{domain}" is throttled' + ) + sq = await self._get_or_create_sub_queue(domain) + await sq.add_request(request) + + # Mark it handled in inner so it is not processed twice. + # We track transfer count to correct get_total_count and get_handled_count. + await self._inner.mark_request_as_handled(request) + self._transferred_requests_count += 1 diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 23ca3c1eca..84c216066f 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -28,7 +28,7 @@ from crawlee.errors import RequestCollisionError, SessionError, UserDefinedErrorHandlerError from crawlee.events import Event, EventCrawlerStatusData from crawlee.events._local_event_manager import LocalEventManager -from crawlee.request_loaders import RequestList, RequestManagerTandem +from crawlee.request_loaders import RequestList, RequestManagerTandem, ThrottlingRequestManager from crawlee.sessions import Session, SessionPool from crawlee.statistics import FinalStatistics from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient @@ -653,7 +653,7 @@ async def test_crawler_get_storages() -> None: crawler = BasicCrawler() rp = await crawler.get_request_manager() - assert isinstance(rp, RequestQueue) + assert isinstance(rp, ThrottlingRequestManager) dataset = await crawler.get_dataset() assert isinstance(dataset, Dataset) @@ -1238,7 +1238,8 @@ async def 
test_crawler_uses_default_storages(tmp_path: Path) -> None: assert dataset is await crawler.get_dataset() assert kvs is await crawler.get_key_value_store() - assert rq is await crawler.get_request_manager() + manager = await crawler.get_request_manager() + assert (manager._inner if isinstance(manager, ThrottlingRequestManager) else manager) is rq async def test_crawler_can_use_other_storages(tmp_path: Path) -> None: @@ -1256,7 +1257,8 @@ async def test_crawler_can_use_other_storages(tmp_path: Path) -> None: assert dataset is not await crawler.get_dataset() assert kvs is not await crawler.get_key_value_store() - assert rq is not await crawler.get_request_manager() + manager = await crawler.get_request_manager() + assert (manager._inner if isinstance(manager, ThrottlingRequestManager) else manager) is not rq async def test_crawler_can_use_other_storages_of_same_type(tmp_path: Path) -> None: @@ -1293,7 +1295,8 @@ async def test_crawler_can_use_other_storages_of_same_type(tmp_path: Path) -> No # Assert that the storages are different assert dataset is not await crawler.get_dataset() assert kvs is not await crawler.get_key_value_store() - assert rq is not await crawler.get_request_manager() + manager = await crawler.get_request_manager() + assert (manager._inner if isinstance(manager, ThrottlingRequestManager) else manager) is not rq # Assert that all storages exists on the filesystem for path in expected_paths: diff --git a/tests/unit/test_throttling_request_manager.py b/tests/unit/test_throttling_request_manager.py new file mode 100644 index 0000000000..307aaa8fe2 --- /dev/null +++ b/tests/unit/test_throttling_request_manager.py @@ -0,0 +1,369 @@ +"""Tests for ThrottlingRequestManager - per-domain delay scheduling.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, patch + +import pytest + +from crawlee._request import Request +from crawlee._utils.http import parse_retry_after_header 
+from crawlee.request_loaders._throttling_request_manager import ThrottlingRequestManager + + +@pytest.fixture +def mock_inner() -> AsyncMock: + """Create a mock RequestManager to wrap.""" + inner = AsyncMock() + inner.fetch_next_request = AsyncMock(return_value=None) + inner.add_request = AsyncMock() + inner.add_requests = AsyncMock() + inner.reclaim_request = AsyncMock() + inner.mark_request_as_handled = AsyncMock() + inner.get_handled_count = AsyncMock(return_value=0) + inner.get_total_count = AsyncMock(return_value=0) + inner.is_empty = AsyncMock(return_value=True) + inner.is_finished = AsyncMock(return_value=True) + inner.drop = AsyncMock() + return inner + + +@pytest.fixture +def manager(mock_inner: AsyncMock) -> ThrottlingRequestManager: + """Create a ThrottlingRequestManager wrapping the mock.""" + return ThrottlingRequestManager(mock_inner) + + +@pytest.fixture(autouse=True) +def mock_request_queue_open() -> AsyncMock: + """Mock RequestQueue.open to avoid hitting real storage during tests.""" + target = 'crawlee.request_loaders._throttling_request_manager.RequestQueue.open' + with patch(target, new_callable=AsyncMock) as mocked: + async def mock_open(*args: any, **kwargs: any) -> AsyncMock: # noqa: ARG001 + sq = AsyncMock() + sq.fetch_next_request = AsyncMock(return_value=None) + sq.add_request = AsyncMock() + sq.reclaim_request = AsyncMock() + sq.mark_request_as_handled = AsyncMock() + sq.get_handled_count = AsyncMock(return_value=0) + sq.get_total_count = AsyncMock(return_value=0) + sq.is_empty = AsyncMock(return_value=True) + sq.is_finished = AsyncMock(return_value=True) + sq.drop = AsyncMock() + return sq + + mocked.side_effect = mock_open + yield mocked + + +def _make_request(url: str) -> Request: + """Helper to create a Request object.""" + return Request.from_url(url) + + +# ── Core Throttling Tests ───────────────────────────────── + + +@pytest.mark.asyncio +async def test_non_throttled_passes_through(manager: ThrottlingRequestManager, mock_inner: 
AsyncMock) -> None: + """Requests for non-throttled domains should return immediately.""" + request = _make_request('https://example.com/page1') + mock_inner.fetch_next_request.return_value = request + + result = await manager.fetch_next_request() + + assert result is not None + assert result.url == 'https://example.com/page1' + + +@pytest.mark.asyncio +async def test_429_triggers_domain_delay(manager: ThrottlingRequestManager) -> None: + """After record_domain_delay(), the domain should be throttled.""" + manager.record_domain_delay('https://example.com/page1') + + assert manager._is_domain_throttled('example.com') + + +@pytest.mark.asyncio +async def test_different_domains_independent(manager: ThrottlingRequestManager) -> None: + """Throttling example.com should NOT affect other-site.com.""" + manager.record_domain_delay('https://example.com/page1') + + assert manager._is_domain_throttled('example.com') + assert not manager._is_domain_throttled('other-site.com') + + +@pytest.mark.asyncio +async def test_exponential_backoff(manager: ThrottlingRequestManager) -> None: + """Consecutive 429s should increase delay exponentially.""" + url = 'https://example.com/page1' + + manager.record_domain_delay(url) + state = manager._domain_states['example.com'] + first_until = state.throttled_until + + manager.record_domain_delay(url) + second_until = state.throttled_until + + assert second_until > first_until + assert state.consecutive_429_count == 2 + + +@pytest.mark.asyncio +async def test_max_delay_cap(manager: ThrottlingRequestManager) -> None: + """Backoff should cap at _MAX_DELAY (60s).""" + url = 'https://example.com/page1' + + for _ in range(20): + manager.record_domain_delay(url) + + state = manager._domain_states['example.com'] + now = datetime.now(timezone.utc) + actual_delay = state.throttled_until - now + + assert actual_delay <= manager._MAX_DELAY + timedelta(seconds=1) + + +@pytest.mark.asyncio +async def test_retry_after_header_priority(manager: 
ThrottlingRequestManager) -> None: + """Explicit Retry-After should override exponential backoff.""" + url = 'https://example.com/page1' + + manager.record_domain_delay(url, retry_after=timedelta(seconds=30)) + + state = manager._domain_states['example.com'] + now = datetime.now(timezone.utc) + actual_delay = state.throttled_until - now + + assert actual_delay > timedelta(seconds=28) + assert actual_delay <= timedelta(seconds=31) + + +@pytest.mark.asyncio +async def test_success_resets_backoff(manager: ThrottlingRequestManager) -> None: + """Successful request should reset the consecutive 429 count.""" + url = 'https://example.com/page1' + + manager.record_domain_delay(url) + manager.record_domain_delay(url) + assert manager._domain_states['example.com'].consecutive_429_count == 2 + + manager.record_success(url) + assert manager._domain_states['example.com'].consecutive_429_count == 0 + + +# ── Crawl-Delay Integration Tests ───────────────────────── + + +@pytest.mark.asyncio +async def test_crawl_delay_integration(manager: ThrottlingRequestManager) -> None: + """set_crawl_delay() should enforce per-domain minimum interval.""" + url = 'https://example.com/page1' + manager.set_crawl_delay(url, 5) + + state = manager._domain_states['example.com'] + assert state.crawl_delay == timedelta(seconds=5) + + +@pytest.mark.asyncio +async def test_crawl_delay_throttles_after_dispatch(manager: ThrottlingRequestManager) -> None: + """After dispatching a request, crawl-delay should throttle the next one.""" + url = 'https://example.com/page1' + manager.set_crawl_delay(url, 5) + + manager._mark_domain_dispatched(url) + + assert manager._is_domain_throttled('example.com') + + +# ── Sleep-Based Scheduling Tests ──────────────────────── + + +@pytest.mark.asyncio +async def test_mixed_throttled_and_unthrottled( + manager: ThrottlingRequestManager, + mock_inner: AsyncMock, + mock_request_queue_open: AsyncMock, +) -> None: + """Throttled domain requests should be moved to sub-queues; 
unthrottled ones returned.""" + throttled_req = _make_request('https://throttled.com/page1') + unthrottled_req = _make_request('https://free.com/page1') + + manager.record_domain_delay('https://throttled.com/page1') + + # inner returns throttled, then unthrottled + mock_inner.fetch_next_request.side_effect = [throttled_req, unthrottled_req] + + result = await manager.fetch_next_request() + + assert result is not None + assert result.url == 'https://free.com/page1' + + # Verify throttled request was moved to sub-queue + mock_request_queue_open.assert_called_once() + assert 'throttled.com' in manager._sub_queues + + sq = manager._sub_queues['throttled.com'] + sq.add_request.assert_called_once_with(throttled_req) + + +@pytest.mark.asyncio +async def test_sleep_instead_of_busy_wait(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + """When all domains are throttled and queue is empty, should sleep (not spin).""" + throttled_req = _make_request('https://throttled.com/page1') + + manager.record_domain_delay('https://throttled.com/page1', retry_after=timedelta(seconds=0.2)) + + # inner queue returns the request first time, then None + mock_inner.fetch_next_request.side_effect = [throttled_req, None] + + target = 'crawlee.request_loaders._throttling_request_manager.asyncio.sleep' + with patch(target, new_callable=AsyncMock) as mock_sleep: + # Instead of actually sleeping, we simulate the time passing by unthrottling the domain + async def sleep_side_effect(*args: any, **kwargs: any) -> None: # noqa: ARG001 + # Clear throttle so recursive call succeeds + manager._domain_states['throttled.com'].throttled_until = datetime.now(timezone.utc) + # Setup the sub-queue to return the request now + sq = manager._sub_queues['throttled.com'] + sq.fetch_next_request.side_effect = [throttled_req, None] + # Must return False then True so loop proceeds + sq.is_empty.side_effect = [False, True] + + mock_sleep.side_effect = sleep_side_effect + + # When request is moved to 
sub-queue, we must ensure it isn't "empty" so it triggers sleep
+        async def mock_add_request(*args: object, **kwargs: object) -> None:  # noqa: ARG001
+            sq = manager._sub_queues['throttled.com']
+            sq.is_empty.return_value = False
+        manager._sub_queues = {'throttled.com': AsyncMock()}
+        manager._sub_queues['throttled.com'].add_request.side_effect = mock_add_request
+        manager._sub_queues['throttled.com'].is_empty.return_value = True
+
+        result = await manager.fetch_next_request()
+
+        mock_sleep.assert_called_once()
+        assert result is not None
+        assert result.url == 'https://throttled.com/page1'
+
+
+# ── Delegation Tests ──────────────────────────────────── 
+
+
+@pytest.mark.asyncio
+async def test_add_request_delegates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
+    request = _make_request('https://example.com')
+    await manager.add_request(request)
+    mock_inner.add_request.assert_called_once_with(request, forefront=False)
+
+
+@pytest.mark.asyncio
+async def test_reclaim_request_delegates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
+    request = _make_request('https://example.com')
+    await manager.reclaim_request(request)
+    mock_inner.reclaim_request.assert_called_once_with(request, forefront=False)
+
+
+@pytest.mark.asyncio
+async def test_reclaim_request_delegates_to_sub_queue(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
+    request = _make_request('https://example.com')
+
+    # Setup state manually assuming it was fetched from a sub-queue
+    sq = AsyncMock()
+    manager._sub_queues['example.com'] = sq
+    manager._dispatched_origins[request.unique_key] = 'example.com'
+
+    await manager.reclaim_request(request)
+
+    sq.reclaim_request.assert_called_once_with(request, forefront=False)
+    mock_inner.reclaim_request.assert_not_called()
+
+
+@pytest.mark.asyncio
+async def test_mark_request_as_handled_delegates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None:
+    request = _make_request('https://example.com')
+    await 
manager.mark_request_as_handled(request) + mock_inner.mark_request_as_handled.assert_called_once_with(request) + + +@pytest.mark.asyncio +async def test_get_handled_count_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + mock_inner.get_handled_count.return_value = 42 + + sq = AsyncMock() + sq.get_handled_count.return_value = 10 + manager._sub_queues['example.com'] = sq + manager._transferred_requests_count = 5 + + assert await manager.get_handled_count() == 47 + + +@pytest.mark.asyncio +async def test_get_total_count_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + mock_inner.get_total_count.return_value = 100 + + sq = AsyncMock() + sq.get_total_count.return_value = 20 + manager._sub_queues['example.com'] = sq + manager._transferred_requests_count = 10 + + assert await manager.get_total_count() == 110 + + +@pytest.mark.asyncio +async def test_is_empty_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + mock_inner.is_empty.return_value = True + assert await manager.is_empty() is True + + sq = AsyncMock() + sq.is_empty.return_value = False + manager._sub_queues['example.com'] = sq + + assert await manager.is_empty() is False + + +@pytest.mark.asyncio +async def test_is_finished_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + mock_inner.is_finished.return_value = True + assert await manager.is_finished() is True + + sq = AsyncMock() + sq.is_finished.return_value = False + manager._sub_queues['example.com'] = sq + + assert await manager.is_finished() is False + + +@pytest.mark.asyncio +async def test_drop_clears_all(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + request = _make_request('https://example.com') + sq = AsyncMock() + manager._sub_queues['example.com'] = sq + manager._dispatched_origins[request.unique_key] = 'inner' + + await manager.drop() + + mock_inner.drop.assert_called_once() + sq.drop.assert_called_once() + 
assert len(manager._sub_queues) == 0 + assert len(manager._dispatched_origins) == 0 + + +# ── Utility Tests ────────────────────────────────────── + + +def test_parse_retry_after_none_value() -> None: + assert parse_retry_after_header(None) is None + + +def test_parse_retry_after_empty_string() -> None: + assert parse_retry_after_header('') is None + + +def test_parse_retry_after_integer_seconds() -> None: + result = parse_retry_after_header('120') + assert result == timedelta(seconds=120) + + +def test_parse_retry_after_invalid_value() -> None: + assert parse_retry_after_header('not-a-date-or-number') is None