From 64f7247b6dfa6621743f47d3f56f55066cc92ab1 Mon Sep 17 00:00:00 2001 From: MrAliHasan Date: Sat, 21 Feb 2026 02:30:36 +0500 Subject: [PATCH 1/3] fix: add per-domain RequestThrottler for 429 backoff (#1437) Add a new RequestThrottler component that handles HTTP 429 (Too Many Requests) responses on a per-domain basis, preventing the autoscaling death spiral where 429s cause concurrency to increase. Key features: - Per-domain tracking: rate limiting on domain A doesn't affect domain B - Exponential backoff: 2s -> 4s -> 8s -> ... capped at 60s - Retry-After header support (both seconds and HTTP-date formats) - Throttled requests are reclaimed to the queue, not dropped - Backoff resets on successful requests to that domain The AutoscaledPool is completely untouched - throttling happens transparently in BasicCrawler.__run_task_function before processing. Integration points: - BasicCrawler: throttle check, 429 recording, success reset - AbstractHttpCrawler: passes URL + Retry-After to detection - PlaywrightCrawler: passes URL + Retry-After to detection Closes #1437 --- src/crawlee/_request_throttler.py | 150 ++++++++++++++++++ .../_abstract_http/_abstract_http_crawler.py | 7 +- src/crawlee/crawlers/_basic/_basic_crawler.py | 70 +++++++- .../_playwright/_playwright_crawler.py | 8 +- tests/unit/test_request_throttler.py | 148 +++++++++++++++++ 5 files changed, 379 insertions(+), 4 deletions(-) create mode 100644 src/crawlee/_request_throttler.py create mode 100644 tests/unit/test_request_throttler.py diff --git a/src/crawlee/_request_throttler.py b/src/crawlee/_request_throttler.py new file mode 100644 index 0000000000..dbe43db283 --- /dev/null +++ b/src/crawlee/_request_throttler.py @@ -0,0 +1,150 @@ +# Per-domain rate limit tracker for handling HTTP 429 responses. +# See: https://github.com/apify/crawlee-python/issues/1437 + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from logging import getLogger +from urllib.parse import urlparse + +from crawlee._utils.docs import docs_group + +logger = getLogger(__name__) + + +@dataclass +class _DomainState: + """Tracks rate limit state for a single domain.""" + + domain: str + """The domain being tracked.""" + + next_allowed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + """Earliest time the next request to this domain is allowed.""" + + consecutive_429_count: int = 0 + """Number of consecutive 429 responses (for exponential backoff).""" + + +@docs_group('Crawlers') +class RequestThrottler: + """Per-domain rate limit tracker and request throttler. + + When a target website returns HTTP 429 (Too Many Requests), this component + tracks the rate limit event per domain and applies exponential backoff. + Requests to other (non-rate-limited) domains are unaffected. + + This solves the "death spiral" problem where 429 responses reduce CPU usage, + causing the `AutoscaledPool` to incorrectly scale UP concurrency. + """ + + _BASE_DELAY = timedelta(seconds=2) + """Initial delay after the first 429 response from a domain.""" + + _MAX_DELAY = timedelta(seconds=60) + """Maximum delay between requests to a rate-limited domain.""" + + def __init__(self) -> None: + self._domain_states: dict[str, _DomainState] = {} + + @staticmethod + def _extract_domain(url: str) -> str: + """Extract the domain (hostname) from a URL. + + Args: + url: The URL to extract the domain from. + + Returns: + The hostname portion of the URL, or an empty string if parsing fails. 
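+ Note that only the hostname is used, so subdomains (e.g. "www.example.com") are throttled independently of their parent domain.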
+ """ + parsed = urlparse(url) + return parsed.hostname or '' + + def record_rate_limit(self, url: str, *, retry_after: timedelta | None = None) -> None: + """Record a 429 Too Many Requests response for the domain of the given URL. + + Increments the consecutive 429 count and calculates the next allowed + request time using exponential backoff or the Retry-After value. + + Args: + url: The URL that received a 429 response. + retry_after: Optional delay from the Retry-After header. If provided, + it takes priority over the calculated exponential backoff. + """ + domain = self._extract_domain(url) + if not domain: + return + + now = datetime.now(timezone.utc) + + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + + state = self._domain_states[domain] + state.consecutive_429_count += 1 + + # Calculate delay: use Retry-After if provided, otherwise exponential backoff. + if retry_after is not None: + delay = retry_after + else: + delay = self._BASE_DELAY * (2 ** (state.consecutive_429_count - 1)) + + # Cap the delay at _MAX_DELAY. + if delay > self._MAX_DELAY: + delay = self._MAX_DELAY + + state.next_allowed_at = now + delay + + logger.info( + f'Rate limit (429) detected for domain "{domain}" ' + f'(consecutive: {state.consecutive_429_count}, delay: {delay.total_seconds():.1f}s)' + ) + + def is_throttled(self, url: str) -> bool: + """Check if requests to the domain of the given URL should be delayed. + + Args: + url: The URL to check. + + Returns: + True if the domain is currently rate-limited and the cooldown has not expired. + """ + domain = self._extract_domain(url) + state = self._domain_states.get(domain) + + if state is None: + return False + + return datetime.now(timezone.utc) < state.next_allowed_at + + def get_delay(self, url: str) -> timedelta: + """Get the remaining delay before the next request to this domain is allowed. + + Args: + url: The URL to check. + + Returns: + The remaining time to wait. Returns zero if no delay is needed. + """ + domain = self._extract_domain(url) + state = self._domain_states.get(domain) + + if state is None: + return timedelta(0) + + remaining = state.next_allowed_at - datetime.now(timezone.utc) + return max(remaining, timedelta(0)) + + def record_success(self, url: str) -> None: + """Record a successful request to the domain, resetting its backoff state. + + Args: + url: The URL that received a successful response. 
+ """ + domain = self._extract_domain(url) + state = self._domain_states.get(domain) + + if state is not None and state.consecutive_429_count > 0: + logger.debug(f'Resetting rate limit state for domain "{domain}" after successful request') + state.consecutive_429_count = 0 diff --git a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py index 7aafa49e2e..6c598a3ae8 100644 --- a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +++ b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py @@ -279,7 +279,12 @@ async def _handle_status_code_response( """ status_code = context.http_response.status_code if self._retry_on_blocked: - self._raise_for_session_blocked_status_code(context.session, status_code) + self._raise_for_session_blocked_status_code( + context.session, + status_code, + url=context.request.url, + retry_after_header=context.http_response.headers.get('retry-after'), + ) self._raise_for_error_status_code(status_code) yield context diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 6451d59461..7ac840be42 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -12,7 +12,7 @@ from asyncio import CancelledError from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence from contextlib import AsyncExitStack, suppress -from datetime import timedelta +from datetime import datetime, timedelta, timezone from functools import partial from io import StringIO from pathlib import Path @@ -46,6 +46,7 @@ from crawlee._utils.docs import docs_group from crawlee._utils.file import atomic_write, export_csv_to_stream, export_json_to_stream from crawlee._utils.recurring_task import RecurringTask +from crawlee._request_throttler import RequestThrottler from crawlee._utils.robots import RobotsTxtFile from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute from crawlee._utils.wait import wait_for @@ -485,6 +486,7 @@ async def persist_state_factory() -> KeyValueStore: self._robots_txt_file_cache: LRUCache[str, RobotsTxtFile] = LRUCache(maxsize=1000) self._robots_txt_lock = asyncio.Lock() self._tld_extractor = TLDExtract(cache_dir=tempfile.TemporaryDirectory().name) + self._request_throttler = RequestThrottler() self._snapshotter = Snapshotter.from_config(config) self._autoscaled_pool = AutoscaledPool( system_status=SystemStatus(self._snapshotter), @@ -1396,6 +1398,15 @@ async def __run_task_function(self) -> None: if request is None: return + # Check if this domain is currently rate-limited (429 backoff). + if self._request_throttler.is_throttled(request.url): + self._logger.debug( + f'Request to {request.url} delayed - domain is rate-limited ' + f'(retry in {self._request_throttler.get_delay(request.url).total_seconds():.1f}s)' + ) + await request_manager.reclaim_request(request) + return + if not (await self._is_allowed_based_on_robots_txt_file(request.url)): self._logger.warning( f'Skipping request {request.url} ({request.unique_key}) because it is disallowed based on robots.txt' @@ -1442,6 +1453,9 @@ async def __run_task_function(self) -> None: await self._mark_request_as_handled(request) + # Record successful request to reset rate limit backoff for this domain. 
+ self._request_throttler.record_success(request.url) + if session and session.is_usable: session.mark_good() @@ -1542,22 +1556,74 @@ def _raise_for_error_status_code(self, status_code: int) -> None: if is_status_code_server_error(status_code) and not is_ignored_status: raise HttpStatusCodeError('Error status code returned', status_code) - def _raise_for_session_blocked_status_code(self, session: Session | None, status_code: int) -> None: + def _raise_for_session_blocked_status_code( + self, + session: Session | None, + status_code: int, + *, + url: str = '', + retry_after_header: str | None = None, + ) -> None: """Raise an exception if the given status code indicates the session is blocked. + If the status code is 429 (Too Many Requests), the domain is recorded as + rate-limited in the `RequestThrottler` for per-domain backoff. + Args: session: The session used for the request. If None, no check is performed. status_code: The HTTP status code to check. + url: The request URL, used for per-domain rate limit tracking. + retry_after_header: The value of the Retry-After response header, if present. Raises: SessionError: If the status code indicates the session is blocked. """ + if status_code == 429 and url: + retry_after = self._parse_retry_after_header(retry_after_header) + self._request_throttler.record_rate_limit(url, retry_after=retry_after) + if session is not None and session.is_blocked_status_code( status_code=status_code, ignore_http_error_status_codes=self._ignore_http_error_status_codes, ): raise SessionError(f'Assuming the session is blocked based on HTTP status code {status_code}') + @staticmethod + def _parse_retry_after_header(value: str | None) -> timedelta | None: + """Parse the Retry-After HTTP header value. + + The header can contain either a number of seconds or an HTTP-date. + See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + + Args: + value: The raw Retry-After header value. + + Returns: + A timedelta representing the delay, or None if the header is missing or unparseable. + """ + if not value: + return None + + # Try parsing as integer seconds first. + try: + seconds = int(value) + return timedelta(seconds=seconds) + except ValueError: + pass + + # Try parsing as HTTP-date (e.g., "Wed, 21 Oct 2015 07:28:00 GMT"). + from email.utils import parsedate_to_datetime + + try: + retry_date = parsedate_to_datetime(value) + delay = retry_date - datetime.now(retry_date.tzinfo or timezone.utc) + if delay.total_seconds() > 0: + return delay + except (ValueError, TypeError): + pass + + return None + def _check_request_collision(self, request: Request, session: Session | None) -> None: """Raise an exception if a request cannot access required resources. 
diff --git a/src/crawlee/crawlers/_playwright/_playwright_crawler.py b/src/crawlee/crawlers/_playwright/_playwright_crawler.py index 6f4b2b0e9d..073247a16e 100644 --- a/src/crawlee/crawlers/_playwright/_playwright_crawler.py +++ b/src/crawlee/crawlers/_playwright/_playwright_crawler.py @@ -459,7 +459,13 @@ async def _handle_status_code_response( """ status_code = context.response.status if self._retry_on_blocked: - self._raise_for_session_blocked_status_code(context.session, status_code) + retry_after_header = context.response.headers.get('retry-after') + self._raise_for_session_blocked_status_code( + context.session, + status_code, + url=context.request.url, + retry_after_header=retry_after_header, + ) self._raise_for_error_status_code(status_code) yield context diff --git a/tests/unit/test_request_throttler.py b/tests/unit/test_request_throttler.py new file mode 100644 index 0000000000..d344b2d9f8 --- /dev/null +++ b/tests/unit/test_request_throttler.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from unittest.mock import patch + +import pytest + +from crawlee._request_throttler import RequestThrottler + + +class TestRequestThrottler: + """Tests for the RequestThrottler per-domain rate limit tracker.""" + + def test_not_throttled_by_default(self) -> None: + """Requests should not be throttled when no 429 has been recorded.""" + throttler = RequestThrottler() + assert not throttler.is_throttled('https://example.com/page1') + assert throttler.get_delay('https://example.com/page1') == timedelta(0) + + def test_throttled_after_rate_limit(self) -> None: + """A domain should be throttled after a 429 is recorded.""" + throttler = RequestThrottler() + throttler.record_rate_limit('https://example.com/page1') + + assert throttler.is_throttled('https://example.com/page2') + assert throttler.get_delay('https://example.com/page2') > timedelta(0) + + def test_different_domains_independent(self) -> None: + """A 429 on domain A should not affect domain B.""" + throttler = RequestThrottler() + throttler.record_rate_limit('https://example.com/page1') + + # example.com should be throttled + assert throttler.is_throttled('https://example.com/other') + + # other-site.com should NOT be throttled + assert not throttler.is_throttled('https://other-site.com/page1') + assert throttler.get_delay('https://other-site.com/page1') == timedelta(0) + + def test_exponential_backoff(self) -> None: + """Consecutive 429s should increase delay exponentially.""" + throttler = RequestThrottler() + + # First 429: delay = 2s (BASE_DELAY * 2^0) + throttler.record_rate_limit('https://example.com/a') + delay_1 = throttler.get_delay('https://example.com/a') + + # Second 429: delay = 4s (BASE_DELAY * 2^1) + throttler.record_rate_limit('https://example.com/b') + delay_2 = throttler.get_delay('https://example.com/b') + + # Third 429: delay = 8s (BASE_DELAY * 2^2) + throttler.record_rate_limit('https://example.com/c') + delay_3 = throttler.get_delay('https://example.com/c') + + # Each subsequent delay should be roughly double the previous + assert delay_2 > delay_1 + assert delay_3 > delay_2 + + def test_max_delay_cap(self) -> None: + """Delay should be capped at MAX_DELAY even with many consecutive 429s.""" + throttler = RequestThrottler() + + # Record many 429s to exceed MAX_DELAY + for _ in range(20): + throttler.record_rate_limit('https://example.com/page') + + delay = throttler.get_delay('https://example.com/page') + assert delay <= RequestThrottler._MAX_DELAY + + def 
test_success_resets_backoff(self) -> None: + """A successful request should reset the consecutive 429 count.""" + throttler = RequestThrottler() + + # Record multiple 429s + throttler.record_rate_limit('https://example.com/a') + throttler.record_rate_limit('https://example.com/b') + throttler.record_rate_limit('https://example.com/c') + + # Record a success + throttler.record_success('https://example.com/page') + + # The internal state should show 0 consecutive 429s + state = throttler._domain_states.get('example.com') + assert state is not None + assert state.consecutive_429_count == 0 + + def test_retry_after_takes_priority(self) -> None: + """Retry-After value should take priority over exponential backoff.""" + throttler = RequestThrottler() + + # Record 429 with a specific Retry-After of 30 seconds + throttler.record_rate_limit('https://example.com/page', retry_after=timedelta(seconds=30)) + + delay = throttler.get_delay('https://example.com/page') + # Delay should be close to 30s (minus time elapsed since recording) + assert delay > timedelta(seconds=29) + assert delay <= timedelta(seconds=30) + + def test_throttle_expires_after_delay(self) -> None: + """A domain should no longer be throttled after the delay expires.""" + throttler = RequestThrottler() + + # Record a 429 and manually set next_allowed_at to the past + throttler.record_rate_limit('https://example.com/page') + state = throttler._domain_states['example.com'] + state.next_allowed_at = datetime.now(timezone.utc) - timedelta(seconds=1) + + assert not throttler.is_throttled('https://example.com/page') + assert throttler.get_delay('https://example.com/page') == timedelta(0) + + def test_empty_url_handling(self) -> None: + """Empty or invalid URLs should not cause errors.""" + throttler = RequestThrottler() + + # These should not raise + throttler.record_rate_limit('') + throttler.record_success('') + assert not throttler.is_throttled('') + + +class TestParseRetryAfterHeader: + """Tests for BasicCrawler._parse_retry_after_header.""" + + def test_none_value(self) -> None: + """None input returns None.""" + from crawlee.crawlers._basic._basic_crawler import BasicCrawler + + assert BasicCrawler._parse_retry_after_header(None) is None + + def test_empty_string(self) -> None: + """Empty string returns None.""" + from crawlee.crawlers._basic._basic_crawler import BasicCrawler + + assert BasicCrawler._parse_retry_after_header('') is None + + def test_integer_seconds(self) -> None: + """Integer value should be parsed as seconds.""" + from crawlee.crawlers._basic._basic_crawler import BasicCrawler + + result = BasicCrawler._parse_retry_after_header('120') + assert result == timedelta(seconds=120) + + def test_invalid_value(self) -> None: + """Invalid values should return None.""" + from crawlee.crawlers._basic._basic_crawler import BasicCrawler + + assert BasicCrawler._parse_retry_after_header('not-a-number') is None From 62ab3b8ae2b066737360c984cd1691882edb8872 Mon Sep 17 00:00:00 2001 From: MrAliHasan Date: Mon, 23 Feb 2026 22:48:18 +0500 Subject: [PATCH 2/3] refactor: replace RequestThrottler with ThrottlingRequestManager MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move per-domain throttling from execution layer (BasicCrawler.__run_task_function) to scheduling layer (ThrottlingRequestManager.fetch_next_request). 
- ThrottlingRequestManager wraps RequestQueue, implements RequestManager interface - fetch_next_request() buffers throttled requests and asyncio.sleep()s when all domains are throttled — eliminates busy-wait and unnecessary queue writes - Unified delay mechanism supports both HTTP 429 backoff and robots.txt crawl-delay (#1396) - parse_retry_after_header moved to crawlee._utils.http - 23 new tests covering throttling, scheduling, delegation, and crawl-delay Addresses #1437, #1396 --- src/crawlee/_request_throttler.py | 150 -------- src/crawlee/_utils/http.py | 41 +++ src/crawlee/crawlers/_basic/_basic_crawler.py | 79 ++-- src/crawlee/request_loaders/__init__.py | 10 +- .../_throttling_request_manager.py | 340 ++++++++++++++++++ tests/unit/test_request_throttler.py | 148 -------- tests/unit/test_throttling_request_manager.py | 323 +++++++++++++++++ 7 files changed, 740 insertions(+), 351 deletions(-) delete mode 100644 src/crawlee/_request_throttler.py create mode 100644 src/crawlee/_utils/http.py create mode 100644 src/crawlee/request_loaders/_throttling_request_manager.py delete mode 100644 tests/unit/test_request_throttler.py create mode 100644 tests/unit/test_throttling_request_manager.py diff --git a/src/crawlee/_request_throttler.py b/src/crawlee/_request_throttler.py deleted file mode 100644 index dbe43db283..0000000000 --- a/src/crawlee/_request_throttler.py +++ /dev/null @@ -1,150 +0,0 @@ -# Per-domain rate limit tracker for handling HTTP 429 responses. -# See: https://github.com/apify/crawlee-python/issues/1437 - -from __future__ import annotations - -from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone -from logging import getLogger -from urllib.parse import urlparse - -from crawlee._utils.docs import docs_group - -logger = getLogger(__name__) - - -@dataclass -class _DomainState: - """Tracks rate limit state for a single domain.""" - - domain: str - """The domain being tracked.""" - - next_allowed_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - """Earliest time the next request to this domain is allowed.""" - - consecutive_429_count: int = 0 - """Number of consecutive 429 responses (for exponential backoff).""" - - -@docs_group('Crawlers') -class RequestThrottler: - """Per-domain rate limit tracker and request throttler. - - When a target website returns HTTP 429 (Too Many Requests), this component - tracks the rate limit event per domain and applies exponential backoff. - Requests to other (non-rate-limited) domains are unaffected. - - This solves the "death spiral" problem where 429 responses reduce CPU usage, - causing the `AutoscaledPool` to incorrectly scale UP concurrency. - """ - - _BASE_DELAY = timedelta(seconds=2) - """Initial delay after the first 429 response from a domain.""" - - _MAX_DELAY = timedelta(seconds=60) - """Maximum delay between requests to a rate-limited domain.""" - - def __init__(self) -> None: - self._domain_states: dict[str, _DomainState] = {} - - @staticmethod - def _extract_domain(url: str) -> str: - """Extract the domain (hostname) from a URL. - - Args: - url: The URL to extract the domain from. - - Returns: - The hostname portion of the URL, or an empty string if parsing fails. - """ - parsed = urlparse(url) - return parsed.hostname or '' - - def record_rate_limit(self, url: str, *, retry_after: timedelta | None = None) -> None: - """Record a 429 Too Many Requests response for the domain of the given URL. 
- - Increments the consecutive 429 count and calculates the next allowed - request time using exponential backoff or the Retry-After value. - - Args: - url: The URL that received a 429 response. - retry_after: Optional delay from the Retry-After header. If provided, - it takes priority over the calculated exponential backoff. - """ - domain = self._extract_domain(url) - if not domain: - return - - now = datetime.now(timezone.utc) - - if domain not in self._domain_states: - self._domain_states[domain] = _DomainState(domain=domain) - - state = self._domain_states[domain] - state.consecutive_429_count += 1 - - # Calculate delay: use Retry-After if provided, otherwise exponential backoff. - if retry_after is not None: - delay = retry_after - else: - delay = self._BASE_DELAY * (2 ** (state.consecutive_429_count - 1)) - - # Cap the delay at _MAX_DELAY. - if delay > self._MAX_DELAY: - delay = self._MAX_DELAY - - state.next_allowed_at = now + delay - - logger.info( - f'Rate limit (429) detected for domain "{domain}" ' - f'(consecutive: {state.consecutive_429_count}, delay: {delay.total_seconds():.1f}s)' - ) - - def is_throttled(self, url: str) -> bool: - """Check if requests to the domain of the given URL should be delayed. - - Args: - url: The URL to check. - - Returns: - True if the domain is currently rate-limited and the cooldown has not expired. - """ - domain = self._extract_domain(url) - state = self._domain_states.get(domain) - - if state is None: - return False - - return datetime.now(timezone.utc) < state.next_allowed_at - - def get_delay(self, url: str) -> timedelta: - """Get the remaining delay before the next request to this domain is allowed. - - Args: - url: The URL to check. - - Returns: - The remaining time to wait. Returns zero if no delay is needed. - """ - domain = self._extract_domain(url) - state = self._domain_states.get(domain) - - if state is None: - return timedelta(0) - - remaining = state.next_allowed_at - datetime.now(timezone.utc) - return max(remaining, timedelta(0)) - - def record_success(self, url: str) -> None: - """Record a successful request to the domain, resetting its backoff state. - - Args: - url: The URL that received a successful response. - """ - domain = self._extract_domain(url) - state = self._domain_states.get(domain) - - if state is not None and state.consecutive_429_count > 0: - logger.debug(f'Resetting rate limit state for domain "{domain}" after successful request') - state.consecutive_429_count = 0 diff --git a/src/crawlee/_utils/http.py b/src/crawlee/_utils/http.py new file mode 100644 index 0000000000..e8f2249a98 --- /dev/null +++ b/src/crawlee/_utils/http.py @@ -0,0 +1,41 @@ +"""HTTP utility functions for Crawlee.""" + +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + + +def parse_retry_after_header(value: str | None) -> timedelta | None: + """Parse the Retry-After HTTP header value. + + The header can contain either a number of seconds or an HTTP-date. + See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + + Args: + value: The raw Retry-After header value. + + Returns: + A timedelta representing the delay, or None if the header is missing or unparseable. + """ + if not value: + return None + + # Try parsing as integer seconds first. + try: + seconds = int(value) + return timedelta(seconds=seconds) + except ValueError: + pass + + # Try parsing as HTTP-date (e.g., "Wed, 21 Oct 2015 07:28:00 GMT"). 
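+ # (An HTTP-date that is already in the past yields a non-positive delta, so the function falls through and returns None.)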
+ from email.utils import parsedate_to_datetime # noqa: PLC0415 + + try: + retry_date = parsedate_to_datetime(value) + delay = retry_date - datetime.now(retry_date.tzinfo or timezone.utc) + if delay.total_seconds() > 0: + return delay + except (ValueError, TypeError): + pass + + return None diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 7ac840be42..32e1957017 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -46,7 +46,8 @@ from crawlee._utils.docs import docs_group from crawlee._utils.file import atomic_write, export_csv_to_stream, export_json_to_stream from crawlee._utils.recurring_task import RecurringTask -from crawlee._request_throttler import RequestThrottler +from crawlee._utils.http import parse_retry_after_header +from crawlee.request_loaders import ThrottlingRequestManager from crawlee._utils.robots import RobotsTxtFile from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute from crawlee._utils.wait import wait_for @@ -486,7 +487,7 @@ async def persist_state_factory() -> KeyValueStore: self._robots_txt_file_cache: LRUCache[str, RobotsTxtFile] = LRUCache(maxsize=1000) self._robots_txt_lock = asyncio.Lock() self._tld_extractor = TLDExtract(cache_dir=tempfile.TemporaryDirectory().name) - self._request_throttler = RequestThrottler() + self._throttling_manager: ThrottlingRequestManager | None = None self._snapshotter = Snapshotter.from_config(config) self._autoscaled_pool = AutoscaledPool( system_status=SystemStatus(self._snapshotter), @@ -613,12 +614,18 @@ async def _get_proxy_info(self, request: Request, session: Session | None) -> Pr ) async def get_request_manager(self) -> RequestManager: - """Return the configured request manager. If none is configured, open and return the default request queue.""" + """Return the configured request manager. If none is configured, open and return the default request queue. + + The returned manager is wrapped with `ThrottlingRequestManager` to enforce + per-domain delays from 429 responses and robots.txt crawl-delay directives. + """ if not self._request_manager: - self._request_manager = await RequestQueue.open( + inner = await RequestQueue.open( storage_client=self._service_locator.get_storage_client(), configuration=self._service_locator.get_configuration(), ) + self._throttling_manager = ThrottlingRequestManager(inner) + self._request_manager = self._throttling_manager return self._request_manager @@ -1398,15 +1405,6 @@ async def __run_task_function(self) -> None: if request is None: return - # Check if this domain is currently rate-limited (429 backoff). - if self._request_throttler.is_throttled(request.url): - self._logger.debug( - f'Request to {request.url} delayed - domain is rate-limited ' - f'(retry in {self._request_throttler.get_delay(request.url).total_seconds():.1f}s)' - ) - await request_manager.reclaim_request(request) - return - if not (await self._is_allowed_based_on_robots_txt_file(request.url)): self._logger.warning( f'Skipping request {request.url} ({request.unique_key}) because it is disallowed based on robots.txt' @@ -1454,7 +1452,8 @@ async def __run_task_function(self) -> None: await self._mark_request_as_handled(request) # Record successful request to reset rate limit backoff for this domain. 
- self._request_throttler.record_success(request.url) + if self._throttling_manager: + self._throttling_manager.record_success(request.url) if session and session.is_usable: session.mark_good() @@ -1579,8 +1578,9 @@ def _raise_for_session_blocked_status_code( SessionError: If the status code indicates the session is blocked. """ if status_code == 429 and url: - retry_after = self._parse_retry_after_header(retry_after_header) - self._request_throttler.record_rate_limit(url, retry_after=retry_after) + retry_after = parse_retry_after_header(retry_after_header) + if self._throttling_manager: + self._throttling_manager.record_domain_delay(url, retry_after=retry_after) if session is not None and session.is_blocked_status_code( status_code=status_code, @@ -1588,41 +1588,7 @@ def _raise_for_session_blocked_status_code( ): raise SessionError(f'Assuming the session is blocked based on HTTP status code {status_code}') - @staticmethod - def _parse_retry_after_header(value: str | None) -> timedelta | None: - """Parse the Retry-After HTTP header value. - - The header can contain either a number of seconds or an HTTP-date. - See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After - - Args: - value: The raw Retry-After header value. - - Returns: - A timedelta representing the delay, or None if the header is missing or unparseable. - """ - if not value: - return None - - # Try parsing as integer seconds first. - try: - seconds = int(value) - return timedelta(seconds=seconds) - except ValueError: - pass - - # Try parsing as HTTP-date (e.g., "Wed, 21 Oct 2015 07:28:00 GMT"). - from email.utils import parsedate_to_datetime - - try: - retry_date = parsedate_to_datetime(value) - delay = retry_date - datetime.now(retry_date.tzinfo or timezone.utc) - if delay.total_seconds() > 0: - return delay - except (ValueError, TypeError): - pass - - return None + # NOTE: _parse_retry_after_header has been moved to crawlee._utils.http.parse_retry_after_header def _check_request_collision(self, request: Request, session: Session | None) -> None: """Raise an exception if a request cannot access required resources. @@ -1648,7 +1614,16 @@ async def _is_allowed_based_on_robots_txt_file(self, url: str) -> bool: if not self._respect_robots_txt_file: return True robots_txt_file = await self._get_robots_txt_file_for_url(url) - return not robots_txt_file or robots_txt_file.is_allowed(url) + if not robots_txt_file: + return True + + # Wire robots.txt crawl-delay into ThrottlingRequestManager (#1396). + if self._throttling_manager: + crawl_delay = robots_txt_file.get_crawl_delay() + if crawl_delay is not None: + self._throttling_manager.set_crawl_delay(url, crawl_delay) + + return robots_txt_file.is_allowed(url) async def _get_robots_txt_file_for_url(self, url: str) -> RobotsTxtFile | None: """Get the RobotsTxtFile for a given URL. 
diff --git a/src/crawlee/request_loaders/__init__.py b/src/crawlee/request_loaders/__init__.py index c04d9aa810..6dd8cccfab 100644 --- a/src/crawlee/request_loaders/__init__.py +++ b/src/crawlee/request_loaders/__init__.py @@ -3,5 +3,13 @@ from ._request_manager import RequestManager from ._request_manager_tandem import RequestManagerTandem from ._sitemap_request_loader import SitemapRequestLoader +from ._throttling_request_manager import ThrottlingRequestManager -__all__ = ['RequestList', 'RequestLoader', 'RequestManager', 'RequestManagerTandem', 'SitemapRequestLoader'] +__all__ = [ + 'RequestList', + 'RequestLoader', + 'RequestManager', + 'RequestManagerTandem', + 'SitemapRequestLoader', + 'ThrottlingRequestManager', +] diff --git a/src/crawlee/request_loaders/_throttling_request_manager.py b/src/crawlee/request_loaders/_throttling_request_manager.py new file mode 100644 index 0000000000..218ad82b8a --- /dev/null +++ b/src/crawlee/request_loaders/_throttling_request_manager.py @@ -0,0 +1,340 @@ +"""A request manager wrapper that enforces per-domain delays. + +Handles both HTTP 429 backoff and robots.txt crawl-delay at the scheduling layer, +eliminating the busy-wait problem described in https://github.com/apify/crawlee-python/issues/1437. + +Also addresses https://github.com/apify/crawlee-python/issues/1396 by providing a unified +delay mechanism for crawl-delay directives. +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from logging import getLogger +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +from typing_extensions import override + +from crawlee._utils.docs import docs_group +from crawlee.request_loaders._request_manager import RequestManager + +if TYPE_CHECKING: + from collections.abc import Sequence + + from crawlee._request import Request + from crawlee.storage_clients.models import ProcessedRequest + +logger = getLogger(__name__) + + +@dataclass +class _DomainState: + """Tracks delay state for a single domain.""" + + domain: str + """The domain being tracked.""" + + throttled_until: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + """Earliest time the next request to this domain is allowed.""" + + consecutive_429_count: int = 0 + """Number of consecutive 429 responses (for exponential backoff).""" + + crawl_delay: timedelta | None = None + """Minimum interval between requests from robots.txt crawl-delay directive.""" + + last_request_at: datetime | None = None + """When the last request to this domain was dispatched.""" + + +@docs_group('Request loaders') +class ThrottlingRequestManager(RequestManager): + """A request manager that wraps another and enforces per-domain delays. + + This moves throttling logic into the scheduling layer instead of the execution + layer. When `fetch_next_request()` is called, it intelligently handles delays: + + - If the next request's domain is not throttled, it returns immediately. + - If the domain is throttled but other requests are available, it buffers the + throttled request and tries the next one. + - If all available requests are throttled, it `asyncio.sleep()`s until the + earliest domain cooldown expires — eliminating busy-wait and unnecessary + queue writes. 
+ + Delay sources: + - HTTP 429 responses (via `record_domain_delay`) + - robots.txt crawl-delay directives (via `set_crawl_delay`) + """ + + _BASE_DELAY = timedelta(seconds=2) + """Initial delay after the first 429 response from a domain.""" + + _MAX_DELAY = timedelta(seconds=60) + """Maximum delay between requests to a rate-limited domain.""" + + _MAX_BUFFER_SIZE = 50 + """Maximum number of requests to buffer before sleeping.""" + + def __init__(self, inner: RequestManager) -> None: + """Initialize the throttling manager. + + Args: + inner: The underlying request manager to wrap (typically a RequestQueue). + """ + self._inner = inner + self._domain_states: dict[str, _DomainState] = {} + self._buffered_requests: list[Request] = [] + + @staticmethod + def _extract_domain(url: str) -> str: + """Extract the domain (hostname) from a URL.""" + parsed = urlparse(url) + return parsed.hostname or '' + + def _get_or_create_state(self, domain: str) -> _DomainState: + """Get or create a domain state entry.""" + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + return self._domain_states[domain] + + def _is_domain_throttled(self, domain: str) -> bool: + """Check if a domain is currently throttled.""" + state = self._domain_states.get(domain) + if state is None: + return False + + now = datetime.now(timezone.utc) + + # Check 429 backoff. + if now < state.throttled_until: + return True + + # Check crawl-delay: enforce minimum interval between requests. + if state.crawl_delay is not None and state.last_request_at is not None: + if now < state.last_request_at + state.crawl_delay: + return True + + return False + + def _get_earliest_available_time(self) -> datetime: + """Get the earliest time any throttled domain becomes available.""" + now = datetime.now(timezone.utc) + earliest = now + self._MAX_DELAY # Fallback upper bound. + + for state in self._domain_states.values(): + # Consider 429 backoff. + if state.throttled_until > now and state.throttled_until < earliest: + earliest = state.throttled_until + + # Consider crawl-delay. + if state.crawl_delay is not None and state.last_request_at is not None: + next_allowed = state.last_request_at + state.crawl_delay + if next_allowed > now and next_allowed < earliest: + earliest = next_allowed + + return earliest + + def record_domain_delay(self, url: str, *, retry_after: timedelta | None = None) -> None: + """Record a 429 Too Many Requests response for the domain of the given URL. + + Increments the consecutive 429 count and calculates the next allowed + request time using exponential backoff or the Retry-After value. + + Args: + url: The URL that received a 429 response. + retry_after: Optional delay from the Retry-After header. If provided, + it takes priority over the calculated exponential backoff. + """ + domain = self._extract_domain(url) + if not domain: + return + + now = datetime.now(timezone.utc) + state = self._get_or_create_state(domain) + state.consecutive_429_count += 1 + + # Calculate delay: use Retry-After if provided, otherwise exponential backoff. + if retry_after is not None: + delay = retry_after + else: + delay = self._BASE_DELAY * (2 ** (state.consecutive_429_count - 1)) + + # Cap the delay. 
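+ # (The cap is applied after the Retry-After check as well, so a single header value cannot park a domain indefinitely.)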
+ if delay > self._MAX_DELAY: + delay = self._MAX_DELAY + + state.throttled_until = now + delay + + logger.info( + f'Rate limit (429) detected for domain "{domain}" ' + f'(consecutive: {state.consecutive_429_count}, delay: {delay.total_seconds():.1f}s)' + ) + + def record_success(self, url: str) -> None: + """Record a successful request, resetting the backoff state for that domain. + + Args: + url: The URL that received a successful response. + """ + domain = self._extract_domain(url) + state = self._domain_states.get(domain) + + if state is not None and state.consecutive_429_count > 0: + logger.debug(f'Resetting rate limit state for domain "{domain}" after successful request') + state.consecutive_429_count = 0 + + def set_crawl_delay(self, url: str, delay_seconds: int) -> None: + """Set the robots.txt crawl-delay for a domain. + + Args: + url: A URL from the domain to throttle. + delay_seconds: The crawl-delay value in seconds. + """ + domain = self._extract_domain(url) + if not domain: + return + + state = self._get_or_create_state(domain) + state.crawl_delay = timedelta(seconds=delay_seconds) + + logger.debug(f'Set crawl-delay for domain "{domain}" to {delay_seconds}s') + + def _mark_domain_dispatched(self, url: str) -> None: + """Record that a request to this domain was just dispatched.""" + domain = self._extract_domain(url) + if domain: + state = self._get_or_create_state(domain) + state.last_request_at = datetime.now(timezone.utc) + + # ────────────────────────────────────────────────────── + # RequestManager interface delegation + smart scheduling + # ────────────────────────────────────────────────────── + + @override + async def drop(self) -> None: + self._buffered_requests.clear() + await self._inner.drop() + + @override + async def add_request(self, request: str | Request, *, forefront: bool = False) -> ProcessedRequest: + return await self._inner.add_request(request, forefront=forefront) + + @override + async def add_requests( + self, + requests: Sequence[str | Request], + *, + forefront: bool = False, + batch_size: int = 1000, + wait_time_between_batches: timedelta = timedelta(seconds=1), + wait_for_all_requests_to_be_added: bool = False, + wait_for_all_requests_to_be_added_timeout: timedelta | None = None, + ) -> None: + return await self._inner.add_requests( + requests, + forefront=forefront, + batch_size=batch_size, + wait_time_between_batches=wait_time_between_batches, + wait_for_all_requests_to_be_added=wait_for_all_requests_to_be_added, + wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout, + ) + + @override + async def reclaim_request(self, request: Request, *, forefront: bool = False) -> ProcessedRequest | None: + return await self._inner.reclaim_request(request, forefront=forefront) + + @override + async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + return await self._inner.mark_request_as_handled(request) + + @override + async def get_handled_count(self) -> int: + return await self._inner.get_handled_count() + + @override + async def get_total_count(self) -> int: + return await self._inner.get_total_count() + + @override + async def is_empty(self) -> bool: + if self._buffered_requests: + return False + return await self._inner.is_empty() + + @override + async def is_finished(self) -> bool: + if self._buffered_requests: + return False + return await self._inner.is_finished() + + @override + async def fetch_next_request(self) -> Request | None: + """Fetch the next request, respecting per-domain 
delays. + + If the next available request belongs to a throttled domain, buffer it and + try the next one. If all available requests are throttled, sleep until the + earliest domain becomes available. + """ + # First, check if any buffered requests are now unthrottled. + still_throttled = [] + for req in self._buffered_requests: + domain = self._extract_domain(req.url) + if not self._is_domain_throttled(domain): + self._mark_domain_dispatched(req.url) + # Return remaining throttled requests to buffer. + self._buffered_requests = still_throttled + return req + still_throttled.append(req) + self._buffered_requests = still_throttled + + # Try fetching from the inner queue. + while True: + request = await self._inner.fetch_next_request() + + if request is None: + # No more requests in the queue. + if self._buffered_requests: + # There are buffered requests waiting for cooldown — sleep and retry. + earliest = self._get_earliest_available_time() + sleep_duration = max( + (earliest - datetime.now(timezone.utc)).total_seconds(), + 0.1, # Minimum sleep to avoid tight loops. + ) + logger.debug( + f'All {len(self._buffered_requests)} buffered request(s) throttled. ' + f'Sleeping {sleep_duration:.1f}s until earliest domain is available.' + ) + await asyncio.sleep(sleep_duration) + # After sleep, recursively try again. + return await self.fetch_next_request() + return None + + domain = self._extract_domain(request.url) + + if not self._is_domain_throttled(domain): + # Domain is clear — dispatch immediately. + self._mark_domain_dispatched(request.url) + return request + + # Domain is throttled — buffer this request. + logger.debug( + f'Request to {request.url} buffered — domain "{domain}" is throttled' + ) + self._buffered_requests.append(request) + + if len(self._buffered_requests) >= self._MAX_BUFFER_SIZE: + # Too many buffered: sleep until earliest cooldown and retry. + earliest = self._get_earliest_available_time() + sleep_duration = max( + (earliest - datetime.now(timezone.utc)).total_seconds(), + 0.1, + ) + logger.debug( + f'Buffer full ({self._MAX_BUFFER_SIZE} requests). ' + f'Sleeping {sleep_duration:.1f}s.' 
+ ) + await asyncio.sleep(sleep_duration) + return await self.fetch_next_request() diff --git a/tests/unit/test_request_throttler.py b/tests/unit/test_request_throttler.py deleted file mode 100644 index d344b2d9f8..0000000000 --- a/tests/unit/test_request_throttler.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timedelta, timezone -from unittest.mock import patch - -import pytest - -from crawlee._request_throttler import RequestThrottler - - -class TestRequestThrottler: - """Tests for the RequestThrottler per-domain rate limit tracker.""" - - def test_not_throttled_by_default(self) -> None: - """Requests should not be throttled when no 429 has been recorded.""" - throttler = RequestThrottler() - assert not throttler.is_throttled('https://example.com/page1') - assert throttler.get_delay('https://example.com/page1') == timedelta(0) - - def test_throttled_after_rate_limit(self) -> None: - """A domain should be throttled after a 429 is recorded.""" - throttler = RequestThrottler() - throttler.record_rate_limit('https://example.com/page1') - - assert throttler.is_throttled('https://example.com/page2') - assert throttler.get_delay('https://example.com/page2') > timedelta(0) - - def test_different_domains_independent(self) -> None: - """A 429 on domain A should not affect domain B.""" - throttler = RequestThrottler() - throttler.record_rate_limit('https://example.com/page1') - - # example.com should be throttled - assert throttler.is_throttled('https://example.com/other') - - # other-site.com should NOT be throttled - assert not throttler.is_throttled('https://other-site.com/page1') - assert throttler.get_delay('https://other-site.com/page1') == timedelta(0) - - def test_exponential_backoff(self) -> None: - """Consecutive 429s should increase delay exponentially.""" - throttler = RequestThrottler() - - # First 429: delay = 2s (BASE_DELAY * 2^0) - throttler.record_rate_limit('https://example.com/a') - delay_1 = throttler.get_delay('https://example.com/a') - - # Second 429: delay = 4s (BASE_DELAY * 2^1) - throttler.record_rate_limit('https://example.com/b') - delay_2 = throttler.get_delay('https://example.com/b') - - # Third 429: delay = 8s (BASE_DELAY * 2^2) - throttler.record_rate_limit('https://example.com/c') - delay_3 = throttler.get_delay('https://example.com/c') - - # Each subsequent delay should be roughly double the previous - assert delay_2 > delay_1 - assert delay_3 > delay_2 - - def test_max_delay_cap(self) -> None: - """Delay should be capped at MAX_DELAY even with many consecutive 429s.""" - throttler = RequestThrottler() - - # Record many 429s to exceed MAX_DELAY - for _ in range(20): - throttler.record_rate_limit('https://example.com/page') - - delay = throttler.get_delay('https://example.com/page') - assert delay <= RequestThrottler._MAX_DELAY - - def test_success_resets_backoff(self) -> None: - """A successful request should reset the consecutive 429 count.""" - throttler = RequestThrottler() - - # Record multiple 429s - throttler.record_rate_limit('https://example.com/a') - throttler.record_rate_limit('https://example.com/b') - throttler.record_rate_limit('https://example.com/c') - - # Record a success - throttler.record_success('https://example.com/page') - - # The internal state should show 0 consecutive 429s - state = throttler._domain_states.get('example.com') - assert state is not None - assert state.consecutive_429_count == 0 - - def test_retry_after_takes_priority(self) -> None: - """Retry-After value should take 
priority over exponential backoff.""" - throttler = RequestThrottler() - - # Record 429 with a specific Retry-After of 30 seconds - throttler.record_rate_limit('https://example.com/page', retry_after=timedelta(seconds=30)) - - delay = throttler.get_delay('https://example.com/page') - # Delay should be close to 30s (minus time elapsed since recording) - assert delay > timedelta(seconds=29) - assert delay <= timedelta(seconds=30) - - def test_throttle_expires_after_delay(self) -> None: - """A domain should no longer be throttled after the delay expires.""" - throttler = RequestThrottler() - - # Record a 429 and manually set next_allowed_at to the past - throttler.record_rate_limit('https://example.com/page') - state = throttler._domain_states['example.com'] - state.next_allowed_at = datetime.now(timezone.utc) - timedelta(seconds=1) - - assert not throttler.is_throttled('https://example.com/page') - assert throttler.get_delay('https://example.com/page') == timedelta(0) - - def test_empty_url_handling(self) -> None: - """Empty or invalid URLs should not cause errors.""" - throttler = RequestThrottler() - - # These should not raise - throttler.record_rate_limit('') - throttler.record_success('') - assert not throttler.is_throttled('') - - -class TestParseRetryAfterHeader: - """Tests for BasicCrawler._parse_retry_after_header.""" - - def test_none_value(self) -> None: - """None input returns None.""" - from crawlee.crawlers._basic._basic_crawler import BasicCrawler - - assert BasicCrawler._parse_retry_after_header(None) is None - - def test_empty_string(self) -> None: - """Empty string returns None.""" - from crawlee.crawlers._basic._basic_crawler import BasicCrawler - - assert BasicCrawler._parse_retry_after_header('') is None - - def test_integer_seconds(self) -> None: - """Integer value should be parsed as seconds.""" - from crawlee.crawlers._basic._basic_crawler import BasicCrawler - - result = BasicCrawler._parse_retry_after_header('120') - assert result == timedelta(seconds=120) - - def test_invalid_value(self) -> None: - """Invalid values should return None.""" - from crawlee.crawlers._basic._basic_crawler import BasicCrawler - - assert BasicCrawler._parse_retry_after_header('not-a-number') is None diff --git a/tests/unit/test_throttling_request_manager.py b/tests/unit/test_throttling_request_manager.py new file mode 100644 index 0000000000..6b2e924d50 --- /dev/null +++ b/tests/unit/test_throttling_request_manager.py @@ -0,0 +1,323 @@ +"""Tests for ThrottlingRequestManager - per-domain delay scheduling. + +Tests cover: 429 backoff, robots.txt crawl-delay, domain independence, +exponential backoff, buffer + sleep behavior, and full RequestManager delegation. 
+""" + +from __future__ import annotations + +import asyncio +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from crawlee._request import Request +from crawlee.request_loaders._throttling_request_manager import ThrottlingRequestManager, _DomainState +from crawlee._utils.http import parse_retry_after_header + + +# ── Fixtures ────────────────────────────────────────────── + + +@pytest.fixture +def mock_inner() -> AsyncMock: + """Create a mock RequestManager to wrap.""" + inner = AsyncMock() + inner.fetch_next_request = AsyncMock(return_value=None) + inner.add_request = AsyncMock() + inner.add_requests = AsyncMock() + inner.reclaim_request = AsyncMock() + inner.mark_request_as_handled = AsyncMock() + inner.get_handled_count = AsyncMock(return_value=0) + inner.get_total_count = AsyncMock(return_value=0) + inner.is_empty = AsyncMock(return_value=True) + inner.is_finished = AsyncMock(return_value=True) + inner.drop = AsyncMock() + return inner + + +@pytest.fixture +def manager(mock_inner: AsyncMock) -> ThrottlingRequestManager: + """Create a ThrottlingRequestManager wrapping the mock.""" + return ThrottlingRequestManager(mock_inner) + + +def _make_request(url: str) -> Request: + """Helper to create a Request object.""" + return Request.from_url(url) + + +# ── Core Throttling Tests ───────────────────────────────── + + +class TestDomainThrottling: + """Tests for per-domain rate limiting.""" + + @pytest.mark.asyncio + async def test_non_throttled_passes_through(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + """Requests for non-throttled domains should return immediately.""" + request = _make_request('https://example.com/page1') + mock_inner.fetch_next_request.return_value = request + + result = await manager.fetch_next_request() + + assert result is not None + assert result.url == 'https://example.com/page1' + + @pytest.mark.asyncio + async def test_429_triggers_domain_delay(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + """After record_domain_delay(), the domain should be throttled.""" + manager.record_domain_delay('https://example.com/page1') + + assert manager._is_domain_throttled('example.com') + + @pytest.mark.asyncio + async def test_different_domains_independent(self, manager: ThrottlingRequestManager) -> None: + """Throttling example.com should NOT affect other-site.com.""" + manager.record_domain_delay('https://example.com/page1') + + assert manager._is_domain_throttled('example.com') + assert not manager._is_domain_throttled('other-site.com') + + @pytest.mark.asyncio + async def test_exponential_backoff(self, manager: ThrottlingRequestManager) -> None: + """Consecutive 429s should increase delay exponentially.""" + url = 'https://example.com/page1' + + # First 429: 2s delay. + manager.record_domain_delay(url) + state = manager._domain_states['example.com'] + first_until = state.throttled_until + + # Second 429: 4s delay. + manager.record_domain_delay(url) + second_until = state.throttled_until + + # The second delay should extend further into the future. + assert second_until > first_until + assert state.consecutive_429_count == 2 + + @pytest.mark.asyncio + async def test_max_delay_cap(self, manager: ThrottlingRequestManager) -> None: + """Backoff should cap at _MAX_DELAY (60s).""" + url = 'https://example.com/page1' + + # Trigger many 429s to hit the cap. 
+ for _ in range(20): + manager.record_domain_delay(url) + + state = manager._domain_states['example.com'] + now = datetime.now(timezone.utc) + actual_delay = state.throttled_until - now + + # Should never exceed MAX_DELAY + small tolerance. + assert actual_delay <= manager._MAX_DELAY + timedelta(seconds=1) + + @pytest.mark.asyncio + async def test_retry_after_header_priority(self, manager: ThrottlingRequestManager) -> None: + """Explicit Retry-After should override exponential backoff.""" + url = 'https://example.com/page1' + + # Record with explicit 30s Retry-After. + manager.record_domain_delay(url, retry_after=timedelta(seconds=30)) + + state = manager._domain_states['example.com'] + now = datetime.now(timezone.utc) + actual_delay = state.throttled_until - now + + # Should be approximately 30s (within tolerance). + assert actual_delay > timedelta(seconds=28) + assert actual_delay <= timedelta(seconds=31) + + @pytest.mark.asyncio + async def test_success_resets_backoff(self, manager: ThrottlingRequestManager) -> None: + """Successful request should reset the consecutive 429 count.""" + url = 'https://example.com/page1' + + manager.record_domain_delay(url) + manager.record_domain_delay(url) + assert manager._domain_states['example.com'].consecutive_429_count == 2 + + manager.record_success(url) + assert manager._domain_states['example.com'].consecutive_429_count == 0 + + +# ── Crawl-Delay Integration Tests ───────────────────────── + + +class TestCrawlDelay: + """Tests for robots.txt crawl-delay integration (#1396).""" + + @pytest.mark.asyncio + async def test_crawl_delay_integration(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + """set_crawl_delay() should enforce per-domain minimum interval.""" + url = 'https://example.com/page1' + manager.set_crawl_delay(url, 5) + + state = manager._domain_states['example.com'] + assert state.crawl_delay == timedelta(seconds=5) + + @pytest.mark.asyncio + async def test_crawl_delay_throttles_after_dispatch( + self, manager: ThrottlingRequestManager, mock_inner: AsyncMock + ) -> None: + """After dispatching a request, crawl-delay should throttle the next one.""" + url = 'https://example.com/page1' + manager.set_crawl_delay(url, 5) + + # Simulate dispatching (which sets last_request_at). + manager._mark_domain_dispatched(url) + + # Domain should now be throttled. + assert manager._is_domain_throttled('example.com') + + +# ── Sleep-Based Scheduling Tests ──────────────────────── + + +class TestSchedulingBehavior: + """Tests for the sleep-based scheduling that eliminates busy-wait.""" + + @pytest.mark.asyncio + async def test_mixed_throttled_and_unthrottled( + self, manager: ThrottlingRequestManager, mock_inner: AsyncMock + ) -> None: + """Throttled domain requests should be buffered; unthrottled ones returned.""" + throttled_req = _make_request('https://throttled.com/page1') + unthrottled_req = _make_request('https://free.com/page1') + + # Throttle one domain. + manager.record_domain_delay('https://throttled.com/page1') + + # Inner queue returns throttled first, then unthrottled. + mock_inner.fetch_next_request.side_effect = [throttled_req, unthrottled_req] + + result = await manager.fetch_next_request() + + # Should skip the throttled one and return the unthrottled one. + assert result is not None + assert result.url == 'https://free.com/page1' + # Throttled request should be in the buffer. 
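+ # (A later fetch_next_request() call hands it out once the domain's cooldown has expired.)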
+ assert len(manager._buffered_requests) == 1 + + @pytest.mark.asyncio + async def test_sleep_instead_of_busy_wait( + self, manager: ThrottlingRequestManager, mock_inner: AsyncMock + ) -> None: + """When all domains are throttled and queue is empty, should sleep (not spin).""" + throttled_req = _make_request('https://throttled.com/page1') + + # Throttle the domain with a very short delay for test speed. + manager.record_domain_delay('https://throttled.com/page1', retry_after=timedelta(seconds=0.2)) + + # First call returns throttled request, second returns None (queue empty). + mock_inner.fetch_next_request.side_effect = [throttled_req, None] + + with patch('crawlee.request_loaders._throttling_request_manager.asyncio.sleep', new_callable=AsyncMock) as mock_sleep: + # Make sleep a no-op but track that it was called. + mock_sleep.return_value = None + + # After sleep, the buffered request should be returned. + # We need the recursive call to find the now-unthrottled buffered request. + # Reset throttle so the recursive call succeeds. + async def sleep_side_effect(duration: float) -> None: + # After sleeping, clear the throttle so the request can be dispatched. + manager._domain_states['throttled.com'].throttled_until = datetime.now(timezone.utc) + + mock_sleep.side_effect = sleep_side_effect + + result = await manager.fetch_next_request() + + # asyncio.sleep should have been called instead of busy-waiting. + mock_sleep.assert_called_once() + assert result is not None + assert result.url == 'https://throttled.com/page1' + + +# ── Delegation Tests ──────────────────────────────────── + + +class TestRequestManagerDelegation: + """Verify all RequestManager methods properly delegate to inner.""" + + @pytest.mark.asyncio + async def test_add_request_delegates(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + request = _make_request('https://example.com') + await manager.add_request(request) + mock_inner.add_request.assert_called_once_with(request, forefront=False) + + @pytest.mark.asyncio + async def test_reclaim_request_delegates(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + request = _make_request('https://example.com') + await manager.reclaim_request(request) + mock_inner.reclaim_request.assert_called_once_with(request, forefront=False) + + @pytest.mark.asyncio + async def test_mark_request_as_handled_delegates( + self, manager: ThrottlingRequestManager, mock_inner: AsyncMock + ) -> None: + request = _make_request('https://example.com') + await manager.mark_request_as_handled(request) + mock_inner.mark_request_as_handled.assert_called_once_with(request) + + @pytest.mark.asyncio + async def test_get_handled_count_delegates( + self, manager: ThrottlingRequestManager, mock_inner: AsyncMock + ) -> None: + mock_inner.get_handled_count.return_value = 42 + assert await manager.get_handled_count() == 42 + + @pytest.mark.asyncio + async def test_get_total_count_delegates( + self, manager: ThrottlingRequestManager, mock_inner: AsyncMock + ) -> None: + mock_inner.get_total_count.return_value = 100 + assert await manager.get_total_count() == 100 + + @pytest.mark.asyncio + async def test_is_empty_with_buffer(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + """is_empty should return False if there are buffered requests.""" + mock_inner.is_empty.return_value = True + assert await manager.is_empty() is True + + # Add a buffered request. 
+ manager._buffered_requests.append(_make_request('https://example.com')) + assert await manager.is_empty() is False + + @pytest.mark.asyncio + async def test_is_finished_with_buffer(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + """is_finished should return False if there are buffered requests.""" + mock_inner.is_finished.return_value = True + assert await manager.is_finished() is True + + manager._buffered_requests.append(_make_request('https://example.com')) + assert await manager.is_finished() is False + + @pytest.mark.asyncio + async def test_drop_clears_buffer(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + """drop() should clear the buffer and delegate.""" + manager._buffered_requests.append(_make_request('https://example.com')) + await manager.drop() + assert len(manager._buffered_requests) == 0 + mock_inner.drop.assert_called_once() + + +# ── Utility Tests ────────────────────────────────────── + + +class TestParseRetryAfterHeader: + """Tests for the extracted parse_retry_after_header utility.""" + + def test_none_value(self) -> None: + assert parse_retry_after_header(None) is None + + def test_empty_string(self) -> None: + assert parse_retry_after_header('') is None + + def test_integer_seconds(self) -> None: + result = parse_retry_after_header('120') + assert result == timedelta(seconds=120) + + def test_invalid_value(self) -> None: + assert parse_retry_after_header('not-a-date-or-number') is None From 1065e9b0baff78eca84bc938839569172f7dcc19 Mon Sep 17 00:00:00 2001 From: MrAliHasan Date: Wed, 25 Feb 2026 06:15:46 +0500 Subject: [PATCH 3/3] refactor: reimplement `ThrottlingRequestManager` with per-domain sub-queues and update its integration across crawlers. --- src/crawlee/_utils/http.py | 2 +- .../_abstract_http/_abstract_http_crawler.py | 2 +- src/crawlee/crawlers/_basic/_basic_crawler.py | 57 ++- .../_playwright/_playwright_crawler.py | 2 +- .../_throttling_request_manager.py | 152 +++--- .../crawlers/_basic/test_basic_crawler.py | 13 +- tests/unit/test_throttling_request_manager.py | 478 ++++++++++-------- 7 files changed, 391 insertions(+), 315 deletions(-) diff --git a/src/crawlee/_utils/http.py b/src/crawlee/_utils/http.py index e8f2249a98..be7a1fa5e4 100644 --- a/src/crawlee/_utils/http.py +++ b/src/crawlee/_utils/http.py @@ -15,7 +15,7 @@ def parse_retry_after_header(value: str | None) -> timedelta | None: value: The raw Retry-After header value. Returns: - A timedelta representing the delay, or None if the header is missing or unparseable. + A timedelta representing the delay, or None if the header is missing or unparsable. 
""" if not value: return None diff --git a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py index 6c598a3ae8..13bf2bd1b1 100644 --- a/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py +++ b/src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py @@ -282,7 +282,7 @@ async def _handle_status_code_response( self._raise_for_session_blocked_status_code( context.session, status_code, - url=context.request.url, + request_url=context.request.url, retry_after_header=context.http_response.headers.get('retry-after'), ) self._raise_for_error_status_code(status_code) diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py index 32e1957017..b514e1e5a5 100644 --- a/src/crawlee/crawlers/_basic/_basic_crawler.py +++ b/src/crawlee/crawlers/_basic/_basic_crawler.py @@ -12,7 +12,7 @@ from asyncio import CancelledError from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence from contextlib import AsyncExitStack, suppress -from datetime import datetime, timedelta, timezone +from datetime import timedelta from functools import partial from io import StringIO from pathlib import Path @@ -45,9 +45,8 @@ ) from crawlee._utils.docs import docs_group from crawlee._utils.file import atomic_write, export_csv_to_stream, export_json_to_stream -from crawlee._utils.recurring_task import RecurringTask from crawlee._utils.http import parse_retry_after_header -from crawlee.request_loaders import ThrottlingRequestManager +from crawlee._utils.recurring_task import RecurringTask from crawlee._utils.robots import RobotsTxtFile from crawlee._utils.urls import convert_to_absolute_url, is_url_absolute from crawlee._utils.wait import wait_for @@ -65,6 +64,7 @@ ) from crawlee.events._types import Event, EventCrawlerStatusData from crawlee.http_clients import ImpitHttpClient +from crawlee.request_loaders import ThrottlingRequestManager from crawlee.router import Router from crawlee.sessions import SessionPool from crawlee.statistics import Statistics, StatisticsState @@ -487,7 +487,6 @@ async def persist_state_factory() -> KeyValueStore: self._robots_txt_file_cache: LRUCache[str, RobotsTxtFile] = LRUCache(maxsize=1000) self._robots_txt_lock = asyncio.Lock() self._tld_extractor = TLDExtract(cache_dir=tempfile.TemporaryDirectory().name) - self._throttling_manager: ThrottlingRequestManager | None = None self._snapshotter = Snapshotter.from_config(config) self._autoscaled_pool = AutoscaledPool( system_status=SystemStatus(self._snapshotter), @@ -624,8 +623,7 @@ async def get_request_manager(self) -> RequestManager: storage_client=self._service_locator.get_storage_client(), configuration=self._service_locator.get_configuration(), ) - self._throttling_manager = ThrottlingRequestManager(inner) - self._request_manager = self._throttling_manager + self._request_manager = ThrottlingRequestManager(inner) return self._request_manager @@ -716,12 +714,21 @@ async def run( await self._session_pool.reset_store() request_manager = await self.get_request_manager() - if purge_request_queue and isinstance(request_manager, RequestQueue): - await request_manager.drop() - self._request_manager = await RequestQueue.open( - storage_client=self._service_locator.get_storage_client(), - configuration=self._service_locator.get_configuration(), - ) + if purge_request_queue: + if isinstance(request_manager, RequestQueue): + await request_manager.drop() + self._request_manager = await 
RequestQueue.open(
+                    storage_client=self._service_locator.get_storage_client(),
+                    configuration=self._service_locator.get_configuration(),
+                )
+            elif isinstance(request_manager, ThrottlingRequestManager):
+                await request_manager.drop()
+                inner = await RequestQueue.open(
+                    storage_client=self._service_locator.get_storage_client(),
+                    configuration=self._service_locator.get_configuration(),
+                )
+                # Re-wrap the fresh queue so per-domain throttling state starts clean.
+                self._request_manager = ThrottlingRequestManager(inner)
 
         if requests is not None:
             await self.add_requests(requests)
@@ -1452,8 +1459,8 @@ async def __run_task_function(self) -> None:
                     await self._mark_request_as_handled(request)
 
                     # Record successful request to reset rate limit backoff for this domain.
-                    if self._throttling_manager:
-                        self._throttling_manager.record_success(request.url)
+                    if isinstance(request_manager, ThrottlingRequestManager):
+                        request_manager.record_success(request.url)
 
                     if session and session.is_usable:
                         session.mark_good()
@@ -1560,27 +1567,30 @@ def _raise_for_session_blocked_status_code(
         session: Session | None,
         status_code: int,
         *,
-        url: str = '',
+        request_url: str = '',
         retry_after_header: str | None = None,
     ) -> None:
         """Raise an exception if the given status code indicates the session is blocked.
 
         If the status code is 429 (Too Many Requests), the domain is recorded as
-        rate-limited in the `RequestThrottler` for per-domain backoff.
+        rate-limited in the `ThrottlingRequestManager` for per-domain backoff.
 
         Args:
             session: The session used for the request. If None, no check is performed.
             status_code: The HTTP status code to check.
-            url: The request URL, used for per-domain rate limit tracking.
+            request_url: The request URL, used for per-domain rate limit tracking.
            retry_after_header: The value of the Retry-After response header, if present.
 
         Raises:
             SessionError: If the status code indicates the session is blocked.
         """
-        if status_code == 429 and url:
+        if status_code == 429 and request_url:  # noqa: PLR2004
             retry_after = parse_retry_after_header(retry_after_header)
-            if self._throttling_manager:
-                self._throttling_manager.record_domain_delay(url, retry_after=retry_after)
+
+            # The request manager is normally created in get_request_manager(); record the
+            # rate limit only once a ThrottlingRequestManager is in place.
+            if isinstance(self._request_manager, ThrottlingRequestManager):
+                self._request_manager.record_domain_delay(request_url, retry_after=retry_after)
 
         if session is not None and session.is_blocked_status_code(
             status_code=status_code,
@@ -1589,7 +1599,6 @@ def _raise_for_session_blocked_status_code(
         raise SessionError(f'Assuming the session is blocked based on HTTP status code {status_code}')
 
     # NOTE: _parse_retry_after_header has been moved to crawlee._utils.http.parse_retry_after_header
-
     def _check_request_collision(self, request: Request, session: Session | None) -> None:
         """Raise an exception if a request cannot access required resources.
 
@@ -1617,11 +1626,11 @@ async def _is_allowed_based_on_robots_txt_file(self, url: str) -> bool:
         if not robots_txt_file:
             return True
 
-        # Wire robots.txt crawl-delay into ThrottlingRequestManager
- if self._throttling_manager: + # Wire robots.txt crawl-delay into ThrottlingRequestManager + if isinstance(self._request_manager, ThrottlingRequestManager): crawl_delay = robots_txt_file.get_crawl_delay() if crawl_delay is not None: - self._throttling_manager.set_crawl_delay(url, crawl_delay) + self._request_manager.set_crawl_delay(url, crawl_delay) return robots_txt_file.is_allowed(url) diff --git a/src/crawlee/crawlers/_playwright/_playwright_crawler.py b/src/crawlee/crawlers/_playwright/_playwright_crawler.py index 073247a16e..9db340c976 100644 --- a/src/crawlee/crawlers/_playwright/_playwright_crawler.py +++ b/src/crawlee/crawlers/_playwright/_playwright_crawler.py @@ -463,7 +463,7 @@ async def _handle_status_code_response( self._raise_for_session_blocked_status_code( context.session, status_code, - url=context.request.url, + request_url=context.request.url, retry_after_header=retry_after_header, ) self._raise_for_error_status_code(status_code) diff --git a/src/crawlee/request_loaders/_throttling_request_manager.py b/src/crawlee/request_loaders/_throttling_request_manager.py index 218ad82b8a..b2bc40eeb4 100644 --- a/src/crawlee/request_loaders/_throttling_request_manager.py +++ b/src/crawlee/request_loaders/_throttling_request_manager.py @@ -20,6 +20,7 @@ from crawlee._utils.docs import docs_group from crawlee.request_loaders._request_manager import RequestManager +from crawlee.storages import RequestQueue if TYPE_CHECKING: from collections.abc import Sequence @@ -58,8 +59,8 @@ class ThrottlingRequestManager(RequestManager): layer. When `fetch_next_request()` is called, it intelligently handles delays: - If the next request's domain is not throttled, it returns immediately. - - If the domain is throttled but other requests are available, it buffers the - throttled request and tries the next one. + - If the domain is throttled but other requests are available, it moves the + throttled request to a per-domain sub-queue and tries the next one. - If all available requests are throttled, it `asyncio.sleep()`s until the earliest domain cooldown expires — eliminating busy-wait and unnecessary queue writes. @@ -75,9 +76,6 @@ class ThrottlingRequestManager(RequestManager): _MAX_DELAY = timedelta(seconds=60) """Maximum delay between requests to a rate-limited domain.""" - _MAX_BUFFER_SIZE = 50 - """Maximum number of requests to buffer before sleeping.""" - def __init__(self, inner: RequestManager) -> None: """Initialize the throttling manager. 
@@ -86,7 +84,9 @@ def __init__(self, inner: RequestManager) -> None: """ self._inner = inner self._domain_states: dict[str, _DomainState] = {} - self._buffered_requests: list[Request] = [] + self._sub_queues: dict[str, RequestQueue] = {} + self._dispatched_origins: dict[str, str] = {} + self._transferred_requests_count = 0 @staticmethod def _extract_domain(url: str) -> str: @@ -94,11 +94,11 @@ def _extract_domain(url: str) -> str: parsed = urlparse(url) return parsed.hostname or '' - def _get_or_create_state(self, domain: str) -> _DomainState: - """Get or create a domain state entry.""" - if domain not in self._domain_states: - self._domain_states[domain] = _DomainState(domain=domain) - return self._domain_states[domain] + async def _get_or_create_sub_queue(self, domain: str) -> RequestQueue: + """Get or create a per-domain sub-queue.""" + if domain not in self._sub_queues: + self._sub_queues[domain] = await RequestQueue.open(alias=f'throttled-{domain}') + return self._sub_queues[domain] def _is_domain_throttled(self, domain: str) -> bool: """Check if a domain is currently throttled.""" @@ -113,11 +113,11 @@ def _is_domain_throttled(self, domain: str) -> bool: return True # Check crawl-delay: enforce minimum interval between requests. - if state.crawl_delay is not None and state.last_request_at is not None: - if now < state.last_request_at + state.crawl_delay: - return True - - return False + return ( + state.crawl_delay is not None + and state.last_request_at is not None + and now < state.last_request_at + state.crawl_delay + ) def _get_earliest_available_time(self) -> datetime: """Get the earliest time any throttled domain becomes available.""" @@ -153,18 +153,16 @@ def record_domain_delay(self, url: str, *, retry_after: timedelta | None = None) return now = datetime.now(timezone.utc) - state = self._get_or_create_state(domain) + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + state = self._domain_states[domain] state.consecutive_429_count += 1 # Calculate delay: use Retry-After if provided, otherwise exponential backoff. - if retry_after is not None: - delay = retry_after - else: - delay = self._BASE_DELAY * (2 ** (state.consecutive_429_count - 1)) + delay = retry_after if retry_after is not None else self._BASE_DELAY * (2 ** (state.consecutive_429_count - 1)) # Cap the delay. 
- if delay > self._MAX_DELAY: - delay = self._MAX_DELAY + delay = min(delay, self._MAX_DELAY) state.throttled_until = now + delay @@ -197,7 +195,9 @@ def set_crawl_delay(self, url: str, delay_seconds: int) -> None: if not domain: return - state = self._get_or_create_state(domain) + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + state = self._domain_states[domain] state.crawl_delay = timedelta(seconds=delay_seconds) logger.debug(f'Set crawl-delay for domain "{domain}" to {delay_seconds}s') @@ -206,7 +206,9 @@ def _mark_domain_dispatched(self, url: str) -> None: """Record that a request to this domain was just dispatched.""" domain = self._extract_domain(url) if domain: - state = self._get_or_create_state(domain) + if domain not in self._domain_states: + self._domain_states[domain] = _DomainState(domain=domain) + state = self._domain_states[domain] state.last_request_at = datetime.now(timezone.utc) # ────────────────────────────────────────────────────── @@ -215,8 +217,11 @@ def _mark_domain_dispatched(self, url: str) -> None: @override async def drop(self) -> None: - self._buffered_requests.clear() await self._inner.drop() + for sq in self._sub_queues.values(): + await sq.drop() + self._sub_queues.clear() + self._dispatched_origins.clear() @override async def add_request(self, request: str | Request, *, forefront: bool = False) -> ProcessedRequest: @@ -244,71 +249,91 @@ async def add_requests( @override async def reclaim_request(self, request: Request, *, forefront: bool = False) -> ProcessedRequest | None: + origin = self._dispatched_origins.get(request.unique_key) + if origin and origin != 'inner' and origin in self._sub_queues: + return await self._sub_queues[origin].reclaim_request(request, forefront=forefront) return await self._inner.reclaim_request(request, forefront=forefront) @override async def mark_request_as_handled(self, request: Request) -> ProcessedRequest | None: + origin = self._dispatched_origins.get(request.unique_key) + if origin and origin != 'inner' and origin in self._sub_queues: + return await self._sub_queues[origin].mark_request_as_handled(request) return await self._inner.mark_request_as_handled(request) @override async def get_handled_count(self) -> int: - return await self._inner.get_handled_count() + count = await self._inner.get_handled_count() + for sq in self._sub_queues.values(): + count += await sq.get_handled_count() + return max(0, count - self._transferred_requests_count) @override async def get_total_count(self) -> int: - return await self._inner.get_total_count() + count = await self._inner.get_total_count() + for sq in self._sub_queues.values(): + count += await sq.get_total_count() + return max(0, count - self._transferred_requests_count) @override async def is_empty(self) -> bool: - if self._buffered_requests: + if not await self._inner.is_empty(): return False - return await self._inner.is_empty() + for sq in self._sub_queues.values(): + if not await sq.is_empty(): + return False + return True @override async def is_finished(self) -> bool: - if self._buffered_requests: + if not await self._inner.is_finished(): return False - return await self._inner.is_finished() + for sq in self._sub_queues.values(): + if not await sq.is_finished(): + return False + return True @override async def fetch_next_request(self) -> Request | None: """Fetch the next request, respecting per-domain delays. - If the next available request belongs to a throttled domain, buffer it and - try the next one. 
If all available requests are throttled, sleep until the - earliest domain becomes available. + If the next available request belongs to a throttled domain, it is moved to + a per-domain sub-queue. If all unhandled requests are in throttled sub-queues, + it sleeps until the earliest domain becomes available. """ - # First, check if any buffered requests are now unthrottled. - still_throttled = [] - for req in self._buffered_requests: - domain = self._extract_domain(req.url) + # First, check if any sub-queues have unthrottled requests. + for domain, sq in self._sub_queues.items(): if not self._is_domain_throttled(domain): - self._mark_domain_dispatched(req.url) - # Return remaining throttled requests to buffer. - self._buffered_requests = still_throttled - return req - still_throttled.append(req) - self._buffered_requests = still_throttled + req = await sq.fetch_next_request() + if req: + self._mark_domain_dispatched(req.url) + self._dispatched_origins[req.unique_key] = domain + return req # Try fetching from the inner queue. while True: request = await self._inner.fetch_next_request() if request is None: - # No more requests in the queue. - if self._buffered_requests: - # There are buffered requests waiting for cooldown — sleep and retry. + # No more requests in inner queue. Check if sub-queues have requests. + have_sq_requests = False + for sq in self._sub_queues.values(): + if not await sq.is_empty(): + have_sq_requests = True + break + + if have_sq_requests: + # Requests exist but domains are throttled. Sleep and retry. earliest = self._get_earliest_available_time() sleep_duration = max( (earliest - datetime.now(timezone.utc)).total_seconds(), 0.1, # Minimum sleep to avoid tight loops. ) logger.debug( - f'All {len(self._buffered_requests)} buffered request(s) throttled. ' + f'Throttled sub-queues have requests. ' f'Sleeping {sleep_duration:.1f}s until earliest domain is available.' ) await asyncio.sleep(sleep_duration) - # After sleep, recursively try again. return await self.fetch_next_request() return None @@ -317,24 +342,17 @@ async def fetch_next_request(self) -> Request | None: if not self._is_domain_throttled(domain): # Domain is clear — dispatch immediately. self._mark_domain_dispatched(request.url) + self._dispatched_origins[request.unique_key] = 'inner' return request - # Domain is throttled — buffer this request. + # Domain is throttled — move this request to the sub-queue. logger.debug( - f'Request to {request.url} buffered — domain "{domain}" is throttled' + f'Request to {request.url} moved to sub-queue — domain "{domain}" is throttled' ) - self._buffered_requests.append(request) - - if len(self._buffered_requests) >= self._MAX_BUFFER_SIZE: - # Too many buffered: sleep until earliest cooldown and retry. - earliest = self._get_earliest_available_time() - sleep_duration = max( - (earliest - datetime.now(timezone.utc)).total_seconds(), - 0.1, - ) - logger.debug( - f'Buffer full ({self._MAX_BUFFER_SIZE} requests). ' - f'Sleeping {sleep_duration:.1f}s.' - ) - await asyncio.sleep(sleep_duration) - return await self.fetch_next_request() + sq = await self._get_or_create_sub_queue(domain) + await sq.add_request(request) + + # Mark it handled in inner so it is not processed twice. + # We track transfer count to correct get_total_count and get_handled_count. 
+ await self._inner.mark_request_as_handled(request) + self._transferred_requests_count += 1 diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py index 23ca3c1eca..84c216066f 100644 --- a/tests/unit/crawlers/_basic/test_basic_crawler.py +++ b/tests/unit/crawlers/_basic/test_basic_crawler.py @@ -28,7 +28,7 @@ from crawlee.errors import RequestCollisionError, SessionError, UserDefinedErrorHandlerError from crawlee.events import Event, EventCrawlerStatusData from crawlee.events._local_event_manager import LocalEventManager -from crawlee.request_loaders import RequestList, RequestManagerTandem +from crawlee.request_loaders import RequestList, RequestManagerTandem, ThrottlingRequestManager from crawlee.sessions import Session, SessionPool from crawlee.statistics import FinalStatistics from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient @@ -653,7 +653,7 @@ async def test_crawler_get_storages() -> None: crawler = BasicCrawler() rp = await crawler.get_request_manager() - assert isinstance(rp, RequestQueue) + assert isinstance(rp, ThrottlingRequestManager) dataset = await crawler.get_dataset() assert isinstance(dataset, Dataset) @@ -1238,7 +1238,8 @@ async def test_crawler_uses_default_storages(tmp_path: Path) -> None: assert dataset is await crawler.get_dataset() assert kvs is await crawler.get_key_value_store() - assert rq is await crawler.get_request_manager() + manager = await crawler.get_request_manager() + assert (manager._inner if isinstance(manager, ThrottlingRequestManager) else manager) is rq async def test_crawler_can_use_other_storages(tmp_path: Path) -> None: @@ -1256,7 +1257,8 @@ async def test_crawler_can_use_other_storages(tmp_path: Path) -> None: assert dataset is not await crawler.get_dataset() assert kvs is not await crawler.get_key_value_store() - assert rq is not await crawler.get_request_manager() + manager = await crawler.get_request_manager() + assert (manager._inner if isinstance(manager, ThrottlingRequestManager) else manager) is not rq async def test_crawler_can_use_other_storages_of_same_type(tmp_path: Path) -> None: @@ -1293,7 +1295,8 @@ async def test_crawler_can_use_other_storages_of_same_type(tmp_path: Path) -> No # Assert that the storages are different assert dataset is not await crawler.get_dataset() assert kvs is not await crawler.get_key_value_store() - assert rq is not await crawler.get_request_manager() + manager = await crawler.get_request_manager() + assert (manager._inner if isinstance(manager, ThrottlingRequestManager) else manager) is not rq # Assert that all storages exists on the filesystem for path in expected_paths: diff --git a/tests/unit/test_throttling_request_manager.py b/tests/unit/test_throttling_request_manager.py index 6b2e924d50..307aaa8fe2 100644 --- a/tests/unit/test_throttling_request_manager.py +++ b/tests/unit/test_throttling_request_manager.py @@ -1,23 +1,15 @@ -"""Tests for ThrottlingRequestManager - per-domain delay scheduling. - -Tests cover: 429 backoff, robots.txt crawl-delay, domain independence, -exponential backoff, buffer + sleep behavior, and full RequestManager delegation. 
-""" +"""Tests for ThrottlingRequestManager - per-domain delay scheduling.""" from __future__ import annotations -import asyncio from datetime import datetime, timedelta, timezone -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest from crawlee._request import Request -from crawlee.request_loaders._throttling_request_manager import ThrottlingRequestManager, _DomainState from crawlee._utils.http import parse_retry_after_header - - -# ── Fixtures ────────────────────────────────────────────── +from crawlee.request_loaders._throttling_request_manager import ThrottlingRequestManager @pytest.fixture @@ -43,6 +35,28 @@ def manager(mock_inner: AsyncMock) -> ThrottlingRequestManager: return ThrottlingRequestManager(mock_inner) +@pytest.fixture(autouse=True) +def mock_request_queue_open() -> AsyncMock: + """Mock RequestQueue.open to avoid hitting real storage during tests.""" + target = 'crawlee.request_loaders._throttling_request_manager.RequestQueue.open' + with patch(target, new_callable=AsyncMock) as mocked: + async def mock_open(*args: any, **kwargs: any) -> AsyncMock: # noqa: ARG001 + sq = AsyncMock() + sq.fetch_next_request = AsyncMock(return_value=None) + sq.add_request = AsyncMock() + sq.reclaim_request = AsyncMock() + sq.mark_request_as_handled = AsyncMock() + sq.get_handled_count = AsyncMock(return_value=0) + sq.get_total_count = AsyncMock(return_value=0) + sq.is_empty = AsyncMock(return_value=True) + sq.is_finished = AsyncMock(return_value=True) + sq.drop = AsyncMock() + return sq + + mocked.side_effect = mock_open + yield mocked + + def _make_request(url: str) -> Request: """Helper to create a Request object.""" return Request.from_url(url) @@ -51,273 +65,305 @@ def _make_request(url: str) -> Request: # ── Core Throttling Tests ───────────────────────────────── -class TestDomainThrottling: - """Tests for per-domain rate limiting.""" +@pytest.mark.asyncio +async def test_non_throttled_passes_through(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + """Requests for non-throttled domains should return immediately.""" + request = _make_request('https://example.com/page1') + mock_inner.fetch_next_request.return_value = request - @pytest.mark.asyncio - async def test_non_throttled_passes_through(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: - """Requests for non-throttled domains should return immediately.""" - request = _make_request('https://example.com/page1') - mock_inner.fetch_next_request.return_value = request + result = await manager.fetch_next_request() - result = await manager.fetch_next_request() + assert result is not None + assert result.url == 'https://example.com/page1' - assert result is not None - assert result.url == 'https://example.com/page1' - @pytest.mark.asyncio - async def test_429_triggers_domain_delay(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: - """After record_domain_delay(), the domain should be throttled.""" - manager.record_domain_delay('https://example.com/page1') +@pytest.mark.asyncio +async def test_429_triggers_domain_delay(manager: ThrottlingRequestManager) -> None: + """After record_domain_delay(), the domain should be throttled.""" + manager.record_domain_delay('https://example.com/page1') - assert manager._is_domain_throttled('example.com') + assert manager._is_domain_throttled('example.com') - @pytest.mark.asyncio - async def test_different_domains_independent(self, manager: ThrottlingRequestManager) -> None: - 
"""Throttling example.com should NOT affect other-site.com.""" - manager.record_domain_delay('https://example.com/page1') - assert manager._is_domain_throttled('example.com') - assert not manager._is_domain_throttled('other-site.com') +@pytest.mark.asyncio +async def test_different_domains_independent(manager: ThrottlingRequestManager) -> None: + """Throttling example.com should NOT affect other-site.com.""" + manager.record_domain_delay('https://example.com/page1') - @pytest.mark.asyncio - async def test_exponential_backoff(self, manager: ThrottlingRequestManager) -> None: - """Consecutive 429s should increase delay exponentially.""" - url = 'https://example.com/page1' + assert manager._is_domain_throttled('example.com') + assert not manager._is_domain_throttled('other-site.com') - # First 429: 2s delay. - manager.record_domain_delay(url) - state = manager._domain_states['example.com'] - first_until = state.throttled_until - # Second 429: 4s delay. - manager.record_domain_delay(url) - second_until = state.throttled_until +@pytest.mark.asyncio +async def test_exponential_backoff(manager: ThrottlingRequestManager) -> None: + """Consecutive 429s should increase delay exponentially.""" + url = 'https://example.com/page1' + + manager.record_domain_delay(url) + state = manager._domain_states['example.com'] + first_until = state.throttled_until - # The second delay should extend further into the future. - assert second_until > first_until - assert state.consecutive_429_count == 2 + manager.record_domain_delay(url) + second_until = state.throttled_until - @pytest.mark.asyncio - async def test_max_delay_cap(self, manager: ThrottlingRequestManager) -> None: - """Backoff should cap at _MAX_DELAY (60s).""" - url = 'https://example.com/page1' + assert second_until > first_until + assert state.consecutive_429_count == 2 - # Trigger many 429s to hit the cap. - for _ in range(20): - manager.record_domain_delay(url) - state = manager._domain_states['example.com'] - now = datetime.now(timezone.utc) - actual_delay = state.throttled_until - now +@pytest.mark.asyncio +async def test_max_delay_cap(manager: ThrottlingRequestManager) -> None: + """Backoff should cap at _MAX_DELAY (60s).""" + url = 'https://example.com/page1' - # Should never exceed MAX_DELAY + small tolerance. - assert actual_delay <= manager._MAX_DELAY + timedelta(seconds=1) + for _ in range(20): + manager.record_domain_delay(url) - @pytest.mark.asyncio - async def test_retry_after_header_priority(self, manager: ThrottlingRequestManager) -> None: - """Explicit Retry-After should override exponential backoff.""" - url = 'https://example.com/page1' + state = manager._domain_states['example.com'] + now = datetime.now(timezone.utc) + actual_delay = state.throttled_until - now - # Record with explicit 30s Retry-After. - manager.record_domain_delay(url, retry_after=timedelta(seconds=30)) + assert actual_delay <= manager._MAX_DELAY + timedelta(seconds=1) - state = manager._domain_states['example.com'] - now = datetime.now(timezone.utc) - actual_delay = state.throttled_until - now - # Should be approximately 30s (within tolerance). 
- assert actual_delay > timedelta(seconds=28) - assert actual_delay <= timedelta(seconds=31) +@pytest.mark.asyncio +async def test_retry_after_header_priority(manager: ThrottlingRequestManager) -> None: + """Explicit Retry-After should override exponential backoff.""" + url = 'https://example.com/page1' - @pytest.mark.asyncio - async def test_success_resets_backoff(self, manager: ThrottlingRequestManager) -> None: - """Successful request should reset the consecutive 429 count.""" - url = 'https://example.com/page1' + manager.record_domain_delay(url, retry_after=timedelta(seconds=30)) - manager.record_domain_delay(url) - manager.record_domain_delay(url) - assert manager._domain_states['example.com'].consecutive_429_count == 2 + state = manager._domain_states['example.com'] + now = datetime.now(timezone.utc) + actual_delay = state.throttled_until - now - manager.record_success(url) - assert manager._domain_states['example.com'].consecutive_429_count == 0 + assert actual_delay > timedelta(seconds=28) + assert actual_delay <= timedelta(seconds=31) + + +@pytest.mark.asyncio +async def test_success_resets_backoff(manager: ThrottlingRequestManager) -> None: + """Successful request should reset the consecutive 429 count.""" + url = 'https://example.com/page1' + + manager.record_domain_delay(url) + manager.record_domain_delay(url) + assert manager._domain_states['example.com'].consecutive_429_count == 2 + + manager.record_success(url) + assert manager._domain_states['example.com'].consecutive_429_count == 0 # ── Crawl-Delay Integration Tests ───────────────────────── -class TestCrawlDelay: - """Tests for robots.txt crawl-delay integration (#1396).""" +@pytest.mark.asyncio +async def test_crawl_delay_integration(manager: ThrottlingRequestManager) -> None: + """set_crawl_delay() should enforce per-domain minimum interval.""" + url = 'https://example.com/page1' + manager.set_crawl_delay(url, 5) - @pytest.mark.asyncio - async def test_crawl_delay_integration(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: - """set_crawl_delay() should enforce per-domain minimum interval.""" - url = 'https://example.com/page1' - manager.set_crawl_delay(url, 5) + state = manager._domain_states['example.com'] + assert state.crawl_delay == timedelta(seconds=5) - state = manager._domain_states['example.com'] - assert state.crawl_delay == timedelta(seconds=5) - @pytest.mark.asyncio - async def test_crawl_delay_throttles_after_dispatch( - self, manager: ThrottlingRequestManager, mock_inner: AsyncMock - ) -> None: - """After dispatching a request, crawl-delay should throttle the next one.""" - url = 'https://example.com/page1' - manager.set_crawl_delay(url, 5) +@pytest.mark.asyncio +async def test_crawl_delay_throttles_after_dispatch(manager: ThrottlingRequestManager) -> None: + """After dispatching a request, crawl-delay should throttle the next one.""" + url = 'https://example.com/page1' + manager.set_crawl_delay(url, 5) - # Simulate dispatching (which sets last_request_at). - manager._mark_domain_dispatched(url) + manager._mark_domain_dispatched(url) - # Domain should now be throttled. 
- assert manager._is_domain_throttled('example.com') + assert manager._is_domain_throttled('example.com') # ── Sleep-Based Scheduling Tests ──────────────────────── -class TestSchedulingBehavior: - """Tests for the sleep-based scheduling that eliminates busy-wait.""" +@pytest.mark.asyncio +async def test_mixed_throttled_and_unthrottled( + manager: ThrottlingRequestManager, + mock_inner: AsyncMock, + mock_request_queue_open: AsyncMock, +) -> None: + """Throttled domain requests should be moved to sub-queues; unthrottled ones returned.""" + throttled_req = _make_request('https://throttled.com/page1') + unthrottled_req = _make_request('https://free.com/page1') + + manager.record_domain_delay('https://throttled.com/page1') + + # inner returns throttled, then unthrottled + mock_inner.fetch_next_request.side_effect = [throttled_req, unthrottled_req] + + result = await manager.fetch_next_request() + + assert result is not None + assert result.url == 'https://free.com/page1' + + # Verify throttled request was moved to sub-queue + mock_request_queue_open.assert_called_once() + assert 'throttled.com' in manager._sub_queues + + sq = manager._sub_queues['throttled.com'] + sq.add_request.assert_called_once_with(throttled_req) + + +@pytest.mark.asyncio +async def test_sleep_instead_of_busy_wait(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + """When all domains are throttled and queue is empty, should sleep (not spin).""" + throttled_req = _make_request('https://throttled.com/page1') - @pytest.mark.asyncio - async def test_mixed_throttled_and_unthrottled( - self, manager: ThrottlingRequestManager, mock_inner: AsyncMock - ) -> None: - """Throttled domain requests should be buffered; unthrottled ones returned.""" - throttled_req = _make_request('https://throttled.com/page1') - unthrottled_req = _make_request('https://free.com/page1') + manager.record_domain_delay('https://throttled.com/page1', retry_after=timedelta(seconds=0.2)) - # Throttle one domain. - manager.record_domain_delay('https://throttled.com/page1') + # inner queue returns the request first time, then None + mock_inner.fetch_next_request.side_effect = [throttled_req, None] - # Inner queue returns throttled first, then unthrottled. 
- mock_inner.fetch_next_request.side_effect = [throttled_req, unthrottled_req] + target = 'crawlee.request_loaders._throttling_request_manager.asyncio.sleep' + with patch(target, new_callable=AsyncMock) as mock_sleep: + # Instead of actually sleeping, we simulate the time passing by unthrottling the domain + async def sleep_side_effect(*args: any, **kwargs: any) -> None: # noqa: ARG001 + # Clear throttle so recursive call succeeds + manager._domain_states['throttled.com'].throttled_until = datetime.now(timezone.utc) + # Setup the sub-queue to return the request now + sq = manager._sub_queues['throttled.com'] + sq.fetch_next_request.side_effect = [throttled_req, None] + # Must return False then True so loop proceeds + sq.is_empty.side_effect = [False, True] + + mock_sleep.side_effect = sleep_side_effect + + # When request is moved to sub-queue, we must ensure it isn't "empty" so it triggers sleep + async def mock_add_request(*args: any, **kwargs: any) -> None: + sq = manager._sub_queues['throttled.com'] + sq.is_empty.return_value = False + manager._sub_queues = {'throttled.com': AsyncMock()} + manager._sub_queues['throttled.com'].add_request.side_effect = mock_add_request + manager._sub_queues['throttled.com'].is_empty.return_value = True result = await manager.fetch_next_request() - # Should skip the throttled one and return the unthrottled one. + mock_sleep.assert_called_once() assert result is not None - assert result.url == 'https://free.com/page1' - # Throttled request should be in the buffer. - assert len(manager._buffered_requests) == 1 + assert result.url == 'https://throttled.com/page1' + - @pytest.mark.asyncio - async def test_sleep_instead_of_busy_wait( - self, manager: ThrottlingRequestManager, mock_inner: AsyncMock - ) -> None: - """When all domains are throttled and queue is empty, should sleep (not spin).""" - throttled_req = _make_request('https://throttled.com/page1') +# ── Delegation Tests ──────────────────────────────────── - # Throttle the domain with a very short delay for test speed. - manager.record_domain_delay('https://throttled.com/page1', retry_after=timedelta(seconds=0.2)) - # First call returns throttled request, second returns None (queue empty). - mock_inner.fetch_next_request.side_effect = [throttled_req, None] +@pytest.mark.asyncio +async def test_add_request_delegates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + request = _make_request('https://example.com') + await manager.add_request(request) + mock_inner.add_request.assert_called_once_with(request, forefront=False) - with patch('crawlee.request_loaders._throttling_request_manager.asyncio.sleep', new_callable=AsyncMock) as mock_sleep: - # Make sleep a no-op but track that it was called. - mock_sleep.return_value = None - # After sleep, the buffered request should be returned. - # We need the recursive call to find the now-unthrottled buffered request. - # Reset throttle so the recursive call succeeds. - async def sleep_side_effect(duration: float) -> None: - # After sleeping, clear the throttle so the request can be dispatched. 
- manager._domain_states['throttled.com'].throttled_until = datetime.now(timezone.utc) +@pytest.mark.asyncio +async def test_reclaim_request_delegates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + request = _make_request('https://example.com') + await manager.reclaim_request(request) + mock_inner.reclaim_request.assert_called_once_with(request, forefront=False) - mock_sleep.side_effect = sleep_side_effect - result = await manager.fetch_next_request() +@pytest.mark.asyncio +async def test_reclaim_request_delegates_to_sub_queue(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + request = _make_request('https://example.com') + + # Setup state manually assuming it was fetched from a sub-queue + sq = AsyncMock() + manager._sub_queues['example.com'] = sq + manager._dispatched_origins[request.unique_key] = 'example.com' - # asyncio.sleep should have been called instead of busy-waiting. - mock_sleep.assert_called_once() - assert result is not None - assert result.url == 'https://throttled.com/page1' + await manager.reclaim_request(request) + sq.reclaim_request.assert_called_once_with(request, forefront=False) + mock_inner.reclaim_request.assert_not_called() -# ── Delegation Tests ──────────────────────────────────── + +@pytest.mark.asyncio +async def test_mark_request_as_handled_delegates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + request = _make_request('https://example.com') + await manager.mark_request_as_handled(request) + mock_inner.mark_request_as_handled.assert_called_once_with(request) + + +@pytest.mark.asyncio +async def test_get_handled_count_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + mock_inner.get_handled_count.return_value = 42 + + sq = AsyncMock() + sq.get_handled_count.return_value = 10 + manager._sub_queues['example.com'] = sq + manager._transferred_requests_count = 5 + + assert await manager.get_handled_count() == 47 + + +@pytest.mark.asyncio +async def test_get_total_count_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + mock_inner.get_total_count.return_value = 100 + + sq = AsyncMock() + sq.get_total_count.return_value = 20 + manager._sub_queues['example.com'] = sq + manager._transferred_requests_count = 10 + + assert await manager.get_total_count() == 110 -class TestRequestManagerDelegation: - """Verify all RequestManager methods properly delegate to inner.""" - - @pytest.mark.asyncio - async def test_add_request_delegates(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: - request = _make_request('https://example.com') - await manager.add_request(request) - mock_inner.add_request.assert_called_once_with(request, forefront=False) - - @pytest.mark.asyncio - async def test_reclaim_request_delegates(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: - request = _make_request('https://example.com') - await manager.reclaim_request(request) - mock_inner.reclaim_request.assert_called_once_with(request, forefront=False) - - @pytest.mark.asyncio - async def test_mark_request_as_handled_delegates( - self, manager: ThrottlingRequestManager, mock_inner: AsyncMock - ) -> None: - request = _make_request('https://example.com') - await manager.mark_request_as_handled(request) - mock_inner.mark_request_as_handled.assert_called_once_with(request) - - @pytest.mark.asyncio - async def test_get_handled_count_delegates( - self, manager: ThrottlingRequestManager, mock_inner: AsyncMock - ) -> None: - 
mock_inner.get_handled_count.return_value = 42 - assert await manager.get_handled_count() == 42 - - @pytest.mark.asyncio - async def test_get_total_count_delegates( - self, manager: ThrottlingRequestManager, mock_inner: AsyncMock - ) -> None: - mock_inner.get_total_count.return_value = 100 - assert await manager.get_total_count() == 100 - - @pytest.mark.asyncio - async def test_is_empty_with_buffer(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: - """is_empty should return False if there are buffered requests.""" - mock_inner.is_empty.return_value = True - assert await manager.is_empty() is True - - # Add a buffered request. - manager._buffered_requests.append(_make_request('https://example.com')) - assert await manager.is_empty() is False - - @pytest.mark.asyncio - async def test_is_finished_with_buffer(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: - """is_finished should return False if there are buffered requests.""" - mock_inner.is_finished.return_value = True - assert await manager.is_finished() is True - - manager._buffered_requests.append(_make_request('https://example.com')) - assert await manager.is_finished() is False - - @pytest.mark.asyncio - async def test_drop_clears_buffer(self, manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: - """drop() should clear the buffer and delegate.""" - manager._buffered_requests.append(_make_request('https://example.com')) - await manager.drop() - assert len(manager._buffered_requests) == 0 - mock_inner.drop.assert_called_once() +@pytest.mark.asyncio +async def test_is_empty_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + mock_inner.is_empty.return_value = True + assert await manager.is_empty() is True + + sq = AsyncMock() + sq.is_empty.return_value = False + manager._sub_queues['example.com'] = sq + + assert await manager.is_empty() is False + + +@pytest.mark.asyncio +async def test_is_finished_aggregates(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + mock_inner.is_finished.return_value = True + assert await manager.is_finished() is True + + sq = AsyncMock() + sq.is_finished.return_value = False + manager._sub_queues['example.com'] = sq + + assert await manager.is_finished() is False + + +@pytest.mark.asyncio +async def test_drop_clears_all(manager: ThrottlingRequestManager, mock_inner: AsyncMock) -> None: + request = _make_request('https://example.com') + sq = AsyncMock() + manager._sub_queues['example.com'] = sq + manager._dispatched_origins[request.unique_key] = 'inner' + + await manager.drop() + + mock_inner.drop.assert_called_once() + sq.drop.assert_called_once() + assert len(manager._sub_queues) == 0 + assert len(manager._dispatched_origins) == 0 # ── Utility Tests ────────────────────────────────────── -class TestParseRetryAfterHeader: - """Tests for the extracted parse_retry_after_header utility.""" +def test_parse_retry_after_none_value() -> None: + assert parse_retry_after_header(None) is None + + +def test_parse_retry_after_empty_string() -> None: + assert parse_retry_after_header('') is None - def test_none_value(self) -> None: - assert parse_retry_after_header(None) is None - def test_empty_string(self) -> None: - assert parse_retry_after_header('') is None +def test_parse_retry_after_integer_seconds() -> None: + result = parse_retry_after_header('120') + assert result == timedelta(seconds=120) - def test_integer_seconds(self) -> None: - result = parse_retry_after_header('120') - assert result == 
timedelta(seconds=120) - def test_invalid_value(self) -> None: - assert parse_retry_after_header('not-a-date-or-number') is None +def test_parse_retry_after_invalid_value() -> None: + assert parse_retry_after_header('not-a-date-or-number') is None
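
A minimal usage sketch (not part of the patch series) of the flow described in the commit messages and the `ThrottlingRequestManager` docstring, once these changes are applied. It assumes the default storage configuration; the URL and the 2-second Retry-After value are illustrative only.

    import asyncio
    from datetime import timedelta

    from crawlee._request import Request
    from crawlee.request_loaders import ThrottlingRequestManager
    from crawlee.storages import RequestQueue


    async def main() -> None:
        # Wrap an ordinary request queue; all RequestManager calls are delegated to it.
        inner = await RequestQueue.open()
        manager = ThrottlingRequestManager(inner)

        await manager.add_request(Request.from_url('https://example.com/a'))
        await manager.add_request(Request.from_url('https://example.com/b'))

        first = await manager.fetch_next_request()
        assert first is not None

        # Pretend the server answered 429 with "Retry-After: 2" and put the request back.
        manager.record_domain_delay(first.url, retry_after=timedelta(seconds=2))
        await manager.reclaim_request(first)

        # All pending requests now belong to a throttled domain, so this call moves them
        # to the per-domain sub-queue and sleeps until the cooldown expires (about 2 s).
        second = await manager.fetch_next_request()
        assert second is not None

        # A successful response resets the exponential backoff for example.com.
        manager.record_success(second.url)
        await manager.mark_request_as_handled(second)


    asyncio.run(main())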