diff --git a/HISTORY.rst b/HISTORY.rst index 3dd2f5a6..cac8b4f1 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,7 +4,13 @@ History ------- -5.0.0 +5.0.1 (2025-01-28) +++++++++++++++++++ + +* Allow ``ip_address`` in the ``Traits`` record to be ``None`` again. The + primary use case for this is from the ``minfraud`` package. + +5.0.0 (2025-01-28) ++++++++++++++++++ * BREAKING: The ``raw`` attribute on the model classes has been replaced diff --git a/docs/conf.py b/docs/conf.py index e8526e92..dc9d5ee4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # # geoip2 documentation build configuration file, created by # sphinx-quickstart on Tue Apr 9 13:34:57 2013. @@ -12,8 +11,8 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys import os +import sys # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the diff --git a/examples/benchmark.py b/examples/benchmark.py index 4a60afcc..b9fbc7c2 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -1,15 +1,15 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- -from __future__ import print_function import argparse -import geoip2.database +import contextlib import random import socket import struct import timeit +import geoip2.database + parser = argparse.ArgumentParser(description="Benchmark maxminddb.") parser.add_argument("--count", default=250000, type=int, help="number of lookups") parser.add_argument("--mode", default=0, type=int, help="reader mode to use") @@ -20,12 +20,10 @@ reader = geoip2.database.Reader(args.file, mode=args.mode) -def lookup_ip_address(): +def lookup_ip_address() -> None: ip = socket.inet_ntoa(struct.pack("!L", random.getrandbits(32))) - try: - record = reader.city(str(ip)) - except geoip2.errors.AddressNotFoundError: - pass + with contextlib.suppress(geoip2.errors.AddressNotFoundError): + reader.city(str(ip)) elapsed = timeit.timeit( diff --git a/geoip2/__init__.py b/geoip2/__init__.py index 43c322a7..2b13eaf7 100644 --- a/geoip2/__init__.py +++ b/geoip2/__init__.py @@ -1,7 +1,7 @@ # pylint:disable=C0111 __title__ = "geoip2" -__version__ = "4.8.1" +__version__ = "5.0.1" __author__ = "Gregory Oschwald" __license__ = "Apache License, Version 2.0" __copyright__ = "Copyright (c) 2013-2025 MaxMind, Inc." diff --git a/geoip2/_internal.py b/geoip2/_internal.py index 6f37ced5..e1970c7e 100644 --- a/geoip2/_internal.py +++ b/geoip2/_internal.py @@ -1,22 +1,21 @@ -"""This package contains internal utilities""" +"""This package contains internal utilities.""" # pylint: disable=too-few-public-methods from abc import ABCMeta -from typing import Any class Model(metaclass=ABCMeta): - """Shared methods for MaxMind model classes""" + """Shared methods for MaxMind model classes.""" - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) and self.to_dict() == other.to_dict() - def __ne__(self, other): + def __ne__(self, other) -> bool: return not self.__eq__(other) # pylint: disable=too-many-branches - def to_dict(self): - """Returns a dict of the object suitable for serialization""" + def to_dict(self) -> dict: + """Returns a dict of the object suitable for serialization.""" result = {} for key, value in self.__dict__.items(): if key.startswith("_"): diff --git a/geoip2/database.py b/geoip2/database.py index 4652d7d0..6c975a3d 100644 --- a/geoip2/database.py +++ b/geoip2/database.py @@ -7,41 +7,41 @@ import inspect import os -from typing import Any, AnyStr, cast, IO, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import IO, Any, AnyStr, Optional, Union, cast import maxminddb - from maxminddb import ( MODE_AUTO, - MODE_MMAP, - MODE_MMAP_EXT, + MODE_FD, MODE_FILE, MODE_MEMORY, - MODE_FD, + MODE_MMAP, + MODE_MMAP_EXT, ) import geoip2 -import geoip2.models import geoip2.errors -from geoip2.types import IPAddress +import geoip2.models from geoip2.models import ( ASN, + ISP, AnonymousIP, City, ConnectionType, Country, Domain, Enterprise, - ISP, ) +from geoip2.types import IPAddress __all__ = [ "MODE_AUTO", - "MODE_MMAP", - "MODE_MMAP_EXT", + "MODE_FD", "MODE_FILE", "MODE_MEMORY", - "MODE_FD", + "MODE_MMAP", + "MODE_MMAP_EXT", "Reader", ] @@ -135,9 +135,9 @@ def country(self, ip_address: IPAddress) -> Country: :returns: :py:class:`geoip2.models.Country` object """ - return cast( - Country, self._model_for(geoip2.models.Country, "Country", ip_address) + Country, + self._model_for(geoip2.models.Country, "Country", ip_address), ) def city(self, ip_address: IPAddress) -> City: @@ -161,7 +161,9 @@ def anonymous_ip(self, ip_address: IPAddress) -> AnonymousIP: return cast( AnonymousIP, self._flat_model_for( - geoip2.models.AnonymousIP, "GeoIP2-Anonymous-IP", ip_address + geoip2.models.AnonymousIP, + "GeoIP2-Anonymous-IP", + ip_address, ), ) @@ -174,7 +176,8 @@ def asn(self, ip_address: IPAddress) -> ASN: """ return cast( - ASN, self._flat_model_for(geoip2.models.ASN, "GeoLite2-ASN", ip_address) + ASN, + self._flat_model_for(geoip2.models.ASN, "GeoLite2-ASN", ip_address), ) def connection_type(self, ip_address: IPAddress) -> ConnectionType: @@ -188,7 +191,9 @@ def connection_type(self, ip_address: IPAddress) -> ConnectionType: return cast( ConnectionType, self._flat_model_for( - geoip2.models.ConnectionType, "GeoIP2-Connection-Type", ip_address + geoip2.models.ConnectionType, + "GeoIP2-Connection-Type", + ip_address, ), ) @@ -227,7 +232,8 @@ def isp(self, ip_address: IPAddress) -> ISP: """ return cast( - ISP, self._flat_model_for(geoip2.models.ISP, "GeoIP2-ISP", ip_address) + ISP, + self._flat_model_for(geoip2.models.ISP, "GeoIP2-ISP", ip_address), ) def _get(self, database_type: str, ip_address: IPAddress) -> Any: @@ -247,19 +253,26 @@ def _get(self, database_type: str, ip_address: IPAddress) -> Any: def _model_for( self, - model_class: Union[Type[Country], Type[Enterprise], Type[City]], + model_class: Union[type[Country], type[Enterprise], type[City]], types: str, ip_address: IPAddress, ) -> Union[Country, Enterprise, City]: (record, prefix_len) = self._get(types, ip_address) return model_class( - self._locales, ip_address=ip_address, prefix_len=prefix_len, **record + self._locales, + ip_address=ip_address, + prefix_len=prefix_len, + **record, ) def _flat_model_for( self, model_class: Union[ - Type[Domain], Type[ISP], Type[ConnectionType], Type[ASN], Type[AnonymousIP] + type[Domain], + type[ISP], + type[ConnectionType], + type[ASN], + type[AnonymousIP], ], types: str, ip_address: IPAddress, @@ -278,5 +291,4 @@ def metadata( def close(self) -> None: """Closes the GeoIP2 database.""" - self._db_reader.close() diff --git a/geoip2/errors.py b/geoip2/errors.py index b3e15d30..71bb57bc 100644 --- a/geoip2/errors.py +++ b/geoip2/errors.py @@ -53,8 +53,7 @@ def __init__( @property def network(self) -> Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]: - """The network for the error""" - + """The network for the error.""" if self.ip_address is None or self._prefix_len is None: return None return ipaddress.ip_network(f"{self.ip_address}/{self._prefix_len}", False) diff --git a/geoip2/models.py b/geoip2/models.py index ac0199c7..dd738bbf 100644 --- a/geoip2/models.py +++ b/geoip2/models.py @@ -14,7 +14,9 @@ # pylint: disable=too-many-instance-attributes,too-few-public-methods,too-many-arguments import ipaddress from abc import ABCMeta -from typing import Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Optional, Union +from ipaddress import IPv4Address, IPv6Address import geoip2.records from geoip2._internal import Model @@ -80,24 +82,26 @@ def __init__( self, locales: Optional[Sequence[str]], *, - continent: Optional[Dict] = None, - country: Optional[Dict] = None, + continent: Optional[dict] = None, + country: Optional[dict] = None, ip_address: Optional[IPAddress] = None, - maxmind: Optional[Dict] = None, + maxmind: Optional[dict] = None, prefix_len: Optional[int] = None, - registered_country: Optional[Dict] = None, - represented_country: Optional[Dict] = None, - traits: Optional[Dict] = None, + registered_country: Optional[dict] = None, + represented_country: Optional[dict] = None, + traits: Optional[dict] = None, **_, ) -> None: self._locales = locales self.continent = geoip2.records.Continent(locales, **(continent or {})) self.country = geoip2.records.Country(locales, **(country or {})) self.registered_country = geoip2.records.Country( - locales, **(registered_country or {}) + locales, + **(registered_country or {}), ) self.represented_country = geoip2.records.RepresentedCountry( - locales, **(represented_country or {}) + locales, + **(represented_country or {}), ) self.maxmind = geoip2.records.MaxMind(**(maxmind or {})) @@ -112,8 +116,8 @@ def __init__( def __repr__(self) -> str: return ( - f"{self.__module__}.{self.__class__.__name__}({repr(self._locales)}, " - f"{', '.join(f'{k}={repr(v)}' for k, v in self.to_dict().items())})" + f"{self.__module__}.{self.__class__.__name__}({self._locales!r}, " + f"{', '.join(f'{k}={v!r}' for k, v in self.to_dict().items())})" ) @@ -197,18 +201,18 @@ def __init__( self, locales: Optional[Sequence[str]], *, - city: Optional[Dict] = None, - continent: Optional[Dict] = None, - country: Optional[Dict] = None, - location: Optional[Dict] = None, + city: Optional[dict] = None, + continent: Optional[dict] = None, + country: Optional[dict] = None, + location: Optional[dict] = None, ip_address: Optional[IPAddress] = None, - maxmind: Optional[Dict] = None, - postal: Optional[Dict] = None, + maxmind: Optional[dict] = None, + postal: Optional[dict] = None, prefix_len: Optional[int] = None, - registered_country: Optional[Dict] = None, - represented_country: Optional[Dict] = None, - subdivisions: Optional[List[Dict]] = None, - traits: Optional[Dict] = None, + registered_country: Optional[dict] = None, + represented_country: Optional[dict] = None, + subdivisions: Optional[list[dict]] = None, + traits: Optional[dict] = None, **_, ) -> None: super().__init__( @@ -357,7 +361,7 @@ class Enterprise(City): class SimpleModel(Model, metaclass=ABCMeta): - """Provides basic methods for non-location models""" + """Provides basic methods for non-location models.""" _ip_address: IPAddress _network: Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]] @@ -387,22 +391,20 @@ def __repr__(self) -> str: f"{self.__module__}.{self.__class__.__name__}(" + repr(str(self._ip_address)) + ", " - + ", ".join(f"{k}={repr(v)}" for k, v in d.items()) + + ", ".join(f"{k}={v!r}" for k, v in d.items()) + ")" ) @property - def ip_address(self): - """The IP address for the record""" - if not isinstance( - self._ip_address, (ipaddress.IPv4Address, ipaddress.IPv6Address) - ): + def ip_address(self) -> Union[IPv4Address, IPv6Address]: + """The IP address for the record.""" + if not isinstance(self._ip_address, (IPv4Address, IPv6Address)): self._ip_address = ipaddress.ip_address(self._ip_address) return self._ip_address @property def network(self) -> Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]: - """The network for the record""" + """The network for the record.""" # This code is duplicated for performance reasons network = self._network if network is not None: @@ -469,7 +471,7 @@ class AnonymousIP(SimpleModel): The IP address used in the lookup. - :type: str + :type: ipaddress.IPv4Address or ipaddress.IPv6Address .. attribute:: network @@ -532,7 +534,7 @@ class ASN(SimpleModel): The IP address used in the lookup. - :type: str + :type: ipaddress.IPv4Address or ipaddress.IPv6Address .. attribute:: network @@ -585,7 +587,7 @@ class ConnectionType(SimpleModel): The IP address used in the lookup. - :type: str + :type: ipaddress.IPv4Address or ipaddress.IPv6Address .. attribute:: network @@ -626,7 +628,7 @@ class Domain(SimpleModel): The IP address used in the lookup. - :type: str + :type: ipaddress.IPv4Address or ipaddress.IPv6Address .. attribute:: network @@ -703,7 +705,7 @@ class ISP(ASN): The IP address used in the lookup. - :type: str + :type: ipaddress.IPv4Address or ipaddress.IPv6Address .. attribute:: network diff --git a/geoip2/records.py b/geoip2/records.py index 17c1f37b..155e20ea 100644 --- a/geoip2/records.py +++ b/geoip2/records.py @@ -11,9 +11,12 @@ # pylint:disable=R0903 from abc import ABCMeta -from typing import Dict, Optional, Type, Sequence, Union +from collections.abc import Sequence +from ipaddress import IPv4Address, IPv6Address +from typing import Optional, Union from geoip2._internal import Model +from geoip2.types import IPAddress class Record(Model, metaclass=ABCMeta): @@ -27,13 +30,13 @@ def __repr__(self) -> str: class PlaceRecord(Record, metaclass=ABCMeta): """All records with :py:attr:`names` subclass :py:class:`PlaceRecord`.""" - names: Dict[str, str] + names: dict[str, str] _locales: Sequence[str] def __init__( self, locales: Optional[Sequence[str]], - names: Optional[Dict[str, str]], + names: Optional[dict[str, str]], ) -> None: if locales is None: locales = ["en"] @@ -97,7 +100,7 @@ def __init__( *, confidence: Optional[int] = None, geoname_id: Optional[int] = None, - names: Optional[Dict[str, str]] = None, + names: Optional[dict[str, str]] = None, **_, ) -> None: self.confidence = confidence @@ -113,7 +116,6 @@ class Continent(PlaceRecord): Attributes: - .. attribute:: code A two character continent code like "NA" (North America) @@ -152,7 +154,7 @@ def __init__( *, code: Optional[str] = None, geoname_id: Optional[int] = None, - names: Optional[Dict[str, str]] = None, + names: Optional[dict[str, str]] = None, **_, ) -> None: self.code = code @@ -167,7 +169,6 @@ class Country(PlaceRecord): Attributes: - .. attribute:: confidence A value from 0-100 indicating MaxMind's confidence that @@ -225,7 +226,7 @@ def __init__( geoname_id: Optional[int] = None, is_in_european_union: bool = False, iso_code: Optional[str] = None, - names: Optional[Dict[str, str]] = None, + names: Optional[dict[str, str]] = None, **_, ) -> None: self.confidence = confidence @@ -244,7 +245,6 @@ class RepresentedCountry(Country): Attributes: - .. attribute:: confidence A value from 0-100 indicating MaxMind's confidence that @@ -307,7 +307,7 @@ def __init__( geoname_id: Optional[int] = None, is_in_european_union: bool = False, iso_code: Optional[str] = None, - names: Optional[Dict[str, str]] = None, + names: Optional[dict[str, str]] = None, # pylint:disable=redefined-builtin type: Optional[str] = None, **_, @@ -470,7 +470,11 @@ class Postal(Record): confidence: Optional[int] def __init__( - self, *, code: Optional[str] = None, confidence: Optional[int] = None, **_ + self, + *, + code: Optional[str] = None, + confidence: Optional[int] = None, + **_, ) -> None: self.code = code self.confidence = confidence @@ -534,7 +538,7 @@ def __init__( confidence: Optional[int] = None, geoname_id: Optional[int] = None, iso_code: Optional[str] = None, - names: Optional[Dict[str, str]] = None, + names: Optional[dict[str, str]] = None, **_, ) -> None: self.confidence = confidence @@ -556,11 +560,12 @@ class Subdivisions(tuple): """ def __new__( - cls: Type["Subdivisions"], locales: Optional[Sequence[str]], *subdivisions + cls: type["Subdivisions"], + locales: Optional[Sequence[str]], + *subdivisions, ) -> "Subdivisions": subobjs = tuple(Subdivision(locales, **x) for x in subdivisions) - obj = super().__new__(cls, subobjs) # type: ignore - return obj + return super().__new__(cls, subobjs) # type: ignore def __init__( self, @@ -646,7 +651,7 @@ class Traits(Record): running on. If the system is behind a NAT, this may differ from the IP address locally assigned to it. - :type: str + :type: ipaddress.IPv4Address or ipaddress.IPv6Address .. attribute:: is_anonymous @@ -838,7 +843,7 @@ class Traits(Record): autonomous_system_organization: Optional[str] connection_type: Optional[str] domain: Optional[str] - _ip_address: Optional[str] + _ip_address: Optional[IPAddress] is_anonymous: bool is_anonymous_proxy: bool is_anonymous_vpn: bool @@ -920,17 +925,20 @@ def __init__( self._prefix_len = prefix_len @property - def ip_address(self): - """The IP address for the record""" - if not isinstance( - self._ip_address, (ipaddress.IPv4Address, ipaddress.IPv6Address) - ): - self._ip_address = ipaddress.ip_address(self._ip_address) - return self._ip_address + def ip_address(self) -> Optional[Union[IPv4Address, IPv6Address]]: + """The IP address for the record.""" + ip_address = self._ip_address + if ip_address is None: + return None + + if not isinstance(ip_address, (IPv4Address, IPv6Address)): + ip_address = ipaddress.ip_address(ip_address) + self._ip_address = ip_address + return ip_address @property def network(self) -> Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]: - """The network for the record""" + """The network for the record.""" # This code is duplicated for performance reasons network = self._network if network is not None: diff --git a/geoip2/types.py b/geoip2/types.py index ba6d2b52..d86f1c0e 100644 --- a/geoip2/types.py +++ b/geoip2/types.py @@ -1,4 +1,4 @@ -"""Provides types used internally""" +"""Provides types used internally.""" from ipaddress import IPv4Address, IPv6Address from typing import Union diff --git a/geoip2/webservice.py b/geoip2/webservice.py index 3158d735..7b1a9e1b 100644 --- a/geoip2/webservice.py +++ b/geoip2/webservice.py @@ -27,7 +27,8 @@ import ipaddress import json -from typing import Any, Dict, cast, Optional, Sequence, Type, Union +from collections.abc import Sequence +from typing import Any, Optional, Union, cast import aiohttp import aiohttp.http @@ -106,7 +107,11 @@ def _handle_success(body: str, uri: str) -> Any: ) from ex def _exception_for_error( - self, status: int, content_type: str, body: str, uri: str + self, + status: int, + content_type: str, + body: str, + uri: str, ) -> GeoIP2Error: if 400 <= status < 500: return self._exception_for_4xx_status(status, content_type, body, uri) @@ -115,7 +120,11 @@ def _exception_for_error( return self._exception_for_non_200_status(status, uri, body) def _exception_for_4xx_status( - self, status: int, content_type: str, body: str, uri: str + self, + status: int, + content_type: str, + body: str, + uri: str, ) -> GeoIP2Error: if not body: return HTTPError( @@ -145,7 +154,10 @@ def _exception_for_4xx_status( if "code" in decoded_body and "error" in decoded_body: return self._exception_for_web_service_error( - decoded_body.get("error"), decoded_body.get("code"), status, uri + decoded_body.get("error"), + decoded_body.get("code"), + status, + uri, ) return HTTPError( "Response contains JSON but it does not specify code or error keys", @@ -156,7 +168,10 @@ def _exception_for_4xx_status( @staticmethod def _exception_for_web_service_error( - message: str, code: str, status: int, uri: str + message: str, + code: str, + status: int, + uri: str, ) -> Union[ AuthenticationError, AddressNotFoundError, @@ -184,7 +199,9 @@ def _exception_for_web_service_error( @staticmethod def _exception_for_5xx_status( - status: int, uri: str, body: Optional[str] + status: int, + uri: str, + body: Optional[str], ) -> HTTPError: return HTTPError( f"Received a server error ({status}) for {uri}", @@ -195,7 +212,9 @@ def _exception_for_5xx_status( @staticmethod def _exception_for_non_200_status( - status: int, uri: str, body: Optional[str] + status: int, + uri: str, + body: Optional[str], ) -> HTTPError: return HTTPError( f"Received a very surprising HTTP status ({status}) for {uri}", @@ -289,7 +308,8 @@ async def city(self, ip_address: IPAddress = "me") -> City: """ return cast( - City, await self._response_for("city", geoip2.models.City, ip_address) + City, + await self._response_for("city", geoip2.models.City, ip_address), ) async def country(self, ip_address: IPAddress = "me") -> Country: @@ -338,7 +358,7 @@ async def _session(self) -> aiohttp.ClientSession: async def _response_for( self, path: str, - model_class: Union[Type[Insights], Type[City], Type[Country]], + model_class: Union[type[Insights], type[City], type[Country]], ip_address: IPAddress, ) -> Union[Country, City, Insights]: uri = self._uri(path, ip_address) @@ -352,8 +372,8 @@ async def _response_for( decoded_body = self._handle_success(body, uri) return model_class(self._locales, **decoded_body) - async def close(self): - """Close underlying session + async def close(self) -> None: + """Close underlying session. This will close the session and any associated connections. """ @@ -421,7 +441,7 @@ class Client(BaseClient): """ _session: requests.Session - _proxies: Optional[Dict[str, str]] + _proxies: Optional[dict[str, str]] def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments self, @@ -465,7 +485,8 @@ def country(self, ip_address: IPAddress = "me") -> Country: """ return cast( - Country, self._response_for("country", geoip2.models.Country, ip_address) + Country, + self._response_for("country", geoip2.models.Country, ip_address), ) def insights(self, ip_address: IPAddress = "me") -> Insights: @@ -482,13 +503,14 @@ def insights(self, ip_address: IPAddress = "me") -> Insights: """ return cast( - Insights, self._response_for("insights", geoip2.models.Insights, ip_address) + Insights, + self._response_for("insights", geoip2.models.Insights, ip_address), ) def _response_for( self, path: str, - model_class: Union[Type[Insights], Type[City], Type[Country]], + model_class: Union[type[Insights], type[City], type[Country]], ip_address: IPAddress, ) -> Union[Country, City, Insights]: uri = self._uri(path, ip_address) @@ -501,8 +523,8 @@ def _response_for( decoded_body = self._handle_success(body, uri) return model_class(self._locales, **decoded_body) - def close(self): - """Close underlying session + def close(self) -> None: + """Close underlying session. This will close the session and any associated connections. """ diff --git a/pyproject.toml b/pyproject.toml index d55b8924..b77899cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "geoip2" -version = "4.8.1" +version = "5.0.1" description = "MaxMind GeoIP2 API" authors = [ {name = "Gregory Oschwald", email = "goschwald@maxmind.com"}, @@ -39,6 +39,33 @@ test = [ "pytest-httpserver>=1.0.10", ] +[tool.ruff.lint] +select = ["ALL"] +ignore = [ + # Skip type annotation on **_ + "ANN003", + + # documenting magic methods + "D105", + + # Line length. We let black handle this for now. + "E501", + + # Don't bother with future imports for type annotations + "FA100", + + # Magic numbers for HTTP status codes seem ok most of the time. + "PLR2004", + + # pytest rules + "PT009", + "PT027", +] + +[tool.ruff.lint.per-file-ignores] +"geoip2/{models,records}.py" = [ "D107", "PLR0913" ] +"tests/*" = ["ANN201", "D"] + [tool.setuptools.package-data] geoip2 = ["py.typed"] diff --git a/tests/database_test.py b/tests/database_test.py index 3008f574..9f6daed4 100644 --- a/tests/database_test.py +++ b/tests/database_test.py @@ -1,18 +1,17 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -from __future__ import unicode_literals import ipaddress import sys import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch sys.path.append("..") +import maxminddb + import geoip2.database import geoip2.errors -import maxminddb try: import maxminddb.extension @@ -48,7 +47,7 @@ def test_unknown_address_network(self) -> None: except geoip2.errors.AddressNotFoundError as e: self.assertEqual(e.network, ipaddress.ip_network("10.0.0.0/8")) except Exception as e: - self.fail(f"Expected AddressNotFoundError, got {type(e)}: {str(e)}") + self.fail(f"Expected AddressNotFoundError, got {type(e)}: {e!s}") finally: reader.close() @@ -64,14 +63,15 @@ def test_wrong_database(self) -> None: def test_invalid_address(self) -> None: reader = geoip2.database.Reader("tests/data/test-data/GeoIP2-City-Test.mmdb") with self.assertRaisesRegex( - ValueError, "u?'invalid' does not appear to be an IPv4 or IPv6 address" + ValueError, + "u?'invalid' does not appear to be an IPv4 or IPv6 address", ): reader.city("invalid") reader.close() def test_anonymous_ip(self) -> None: reader = geoip2.database.Reader( - "tests/data/test-data/GeoIP2-Anonymous-IP-Test.mmdb" + "tests/data/test-data/GeoIP2-Anonymous-IP-Test.mmdb", ) ip_address = "1.2.0.1" @@ -88,7 +88,7 @@ def test_anonymous_ip(self) -> None: def test_anonymous_ip_all_set(self) -> None: reader = geoip2.database.Reader( - "tests/data/test-data/GeoIP2-Anonymous-IP-Test.mmdb" + "tests/data/test-data/GeoIP2-Anonymous-IP-Test.mmdb", ) ip_address = "81.2.69.1" @@ -129,11 +129,15 @@ def test_city(self) -> None: record = reader.city("81.2.69.160") self.assertEqual( - record.country.name, "United Kingdom", "The default locale is en" + record.country.name, + "United Kingdom", + "The default locale is en", ) self.assertEqual(record.country.is_in_european_union, False) self.assertEqual( - record.location.accuracy_radius, 100, "The accuracy_radius is populated" + record.location.accuracy_radius, + 100, + "The accuracy_radius is populated", ) self.assertEqual(record.registered_country.is_in_european_union, False) self.assertFalse(record.traits.is_anycast) @@ -145,14 +149,16 @@ def test_city(self) -> None: def test_connection_type(self) -> None: reader = geoip2.database.Reader( - "tests/data/test-data/GeoIP2-Connection-Type-Test.mmdb" + "tests/data/test-data/GeoIP2-Connection-Type-Test.mmdb", ) ip_address = "1.0.1.0" record = reader.connection_type(ip_address) self.assertEqual( - record, eval(repr(record)), "ConnectionType repr can be eval'd" + record, + eval(repr(record)), + "ConnectionType repr can be eval'd", ) self.assertEqual(record.connection_type, "Cellular") @@ -207,7 +213,7 @@ def test_domain(self) -> None: def test_enterprise(self) -> None: with geoip2.database.Reader( - "tests/data/test-data/GeoIP2-Enterprise-Test.mmdb" + "tests/data/test-data/GeoIP2-Enterprise-Test.mmdb", ) as reader: ip_address = "74.209.24.0" record = reader.enterprise(ip_address) @@ -221,7 +227,8 @@ def test_enterprise(self) -> None: self.assertTrue(record.traits.is_legitimate_proxy) self.assertEqual(record.traits.ip_address, ipaddress.ip_address(ip_address)) self.assertEqual( - record.traits.network, ipaddress.ip_network("74.209.16.0/20") + record.traits.network, + ipaddress.ip_network("74.209.16.0/20"), ) self.assertFalse(record.traits.is_anycast) @@ -234,7 +241,7 @@ def test_enterprise(self) -> None: def test_isp(self) -> None: with geoip2.database.Reader( - "tests/data/test-data/GeoIP2-ISP-Test.mmdb" + "tests/data/test-data/GeoIP2-ISP-Test.mmdb", ) as reader: ip_address = "1.128.0.0" record = reader.isp(ip_address) @@ -260,11 +267,12 @@ def test_isp(self) -> None: def test_context_manager(self) -> None: with geoip2.database.Reader( - "tests/data/test-data/GeoIP2-Country-Test.mmdb" + "tests/data/test-data/GeoIP2-Country-Test.mmdb", ) as reader: record = reader.country("81.2.69.160") self.assertEqual( - record.traits.ip_address, ipaddress.ip_address("81.2.69.160") + record.traits.ip_address, + ipaddress.ip_address("81.2.69.160"), ) @patch("maxminddb.open_database") @@ -275,5 +283,5 @@ def test_modes(self, mock_open) -> None: with geoip2.database.Reader( path, mode=geoip2.database.MODE_MMAP_EXT, - ) as reader: + ): mock_open.assert_called_once_with(path, geoip2.database.MODE_MMAP_EXT) diff --git a/tests/models_test.py b/tests/models_test.py index 3f72ec27..58bff6e8 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -1,11 +1,8 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -import sys import ipaddress -from typing import Dict +import sys import unittest sys.path.append("..") @@ -14,7 +11,7 @@ class TestModels(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = 20_000 def test_insights_full(self) -> None: @@ -98,10 +95,14 @@ def test_insights_full(self) -> None: model = geoip2.models.Insights(["en"], **raw) # type: ignore self.assertEqual( - type(model), geoip2.models.Insights, "geoip2.models.Insights object" + type(model), + geoip2.models.Insights, + "geoip2.models.Insights object", ) self.assertEqual( - type(model.city), geoip2.records.City, "geoip2.records.City object" + type(model.city), + geoip2.records.City, + "geoip2.records.City object", ) self.assertEqual( type(model.continent), @@ -109,7 +110,9 @@ def test_insights_full(self) -> None: "geoip2.records.Continent object", ) self.assertEqual( - type(model.country), geoip2.records.Country, "geoip2.records.Country object" + type(model.country), + geoip2.records.Country, + "geoip2.records.Country object", ) self.assertEqual( type(model.registered_country), @@ -132,23 +135,35 @@ def test_insights_full(self) -> None: "geoip2.records.Subdivision object", ) self.assertEqual( - type(model.traits), geoip2.records.Traits, "geoip2.records.Traits object" + type(model.traits), + geoip2.records.Traits, + "geoip2.records.Traits object", ) self.assertEqual(model.to_dict(), raw, "to_dict() method matches raw input") self.assertEqual( - model.subdivisions[0].iso_code, "MN", "div 1 has correct iso_code" + model.subdivisions[0].iso_code, + "MN", + "div 1 has correct iso_code", ) self.assertEqual( - model.subdivisions[0].confidence, 88, "div 1 has correct confidence" + model.subdivisions[0].confidence, + 88, + "div 1 has correct confidence", ) self.assertEqual( - model.subdivisions[0].geoname_id, 574635, "div 1 has correct geoname_id" + model.subdivisions[0].geoname_id, + 574635, + "div 1 has correct geoname_id", ) self.assertEqual( - model.subdivisions[0].names, {"en": "Minnesota"}, "div 1 names are correct" + model.subdivisions[0].names, + {"en": "Minnesota"}, + "div 1 names are correct", ) self.assertEqual( - model.subdivisions[1].name, "Hennepin", "div 2 has correct name" + model.subdivisions[1].name, + "Hennepin", + "div 2 has correct name", ) self.assertEqual( model.subdivisions.most_specific.iso_code, @@ -170,7 +185,9 @@ def test_insights_full(self) -> None: self.assertEqual(model.location.longitude, 93.2636, "correct longitude") self.assertEqual(model.location.metro_code, 765, "correct metro_code") self.assertEqual( - model.location.population_density, 1341, "correct population_density" + model.location.population_density, + 1341, + "correct population_density", ) self.assertRegex( @@ -188,7 +205,9 @@ def test_insights_full(self) -> None: ) self.assertEqual( - model.location, eval(repr(model.location)), "Location repr can be eval'd" + model.location, + eval(repr(model.location)), + "Location repr can be eval'd", ) self.assertIs(model.country.is_in_european_union, False) @@ -210,10 +229,14 @@ def test_insights_full(self) -> None: def test_insights_min(self) -> None: model = geoip2.models.Insights(["en"], traits={"ip_address": "5.6.7.8"}) self.assertEqual( - type(model), geoip2.models.Insights, "geoip2.models.Insights object" + type(model), + geoip2.models.Insights, + "geoip2.models.Insights object", ) self.assertEqual( - type(model.city), geoip2.records.City, "geoip2.records.City object" + type(model.city), + geoip2.records.City, + "geoip2.records.City object", ) self.assertEqual( type(model.continent), @@ -221,7 +244,9 @@ def test_insights_min(self) -> None: "geoip2.records.Continent object", ) self.assertEqual( - type(model.country), geoip2.records.Country, "geoip2.records.Country object" + type(model.country), + geoip2.records.Country, + "geoip2.records.Country object", ) self.assertEqual( type(model.registered_country), @@ -234,7 +259,9 @@ def test_insights_min(self) -> None: "geoip2.records.Location object", ) self.assertEqual( - type(model.traits), geoip2.records.Traits, "geoip2.records.Traits object" + type(model.traits), + geoip2.records.Traits, + "geoip2.records.Traits object", ) self.assertEqual( type(model.subdivisions.most_specific), @@ -242,7 +269,9 @@ def test_insights_min(self) -> None: "geoip2.records.Subdivision object returned even when none are available.", ) self.assertEqual( - model.subdivisions.most_specific.names, {}, "Empty names hash returned" + model.subdivisions.most_specific.names, + {}, + "Empty names hash returned", ) def test_city_full(self) -> None: @@ -270,7 +299,9 @@ def test_city_full(self) -> None: model = geoip2.models.City(["en"], **raw) # type: ignore self.assertEqual(type(model), geoip2.models.City, "geoip2.models.City object") self.assertEqual( - type(model.city), geoip2.records.City, "geoip2.records.City object" + type(model.city), + geoip2.records.City, + "geoip2.records.City object", ) self.assertEqual( type(model.continent), @@ -278,7 +309,9 @@ def test_city_full(self) -> None: "geoip2.records.Continent object", ) self.assertEqual( - type(model.country), geoip2.records.Country, "geoip2.records.Country object" + type(model.country), + geoip2.records.Country, + "geoip2.records.Country object", ) self.assertEqual( type(model.registered_country), @@ -291,18 +324,26 @@ def test_city_full(self) -> None: "geoip2.records.Location object", ) self.assertEqual( - type(model.traits), geoip2.records.Traits, "geoip2.records.Traits object" + type(model.traits), + geoip2.records.Traits, + "geoip2.records.Traits object", ) self.assertEqual( - model.to_dict(), raw, "to_dict method output matches raw input" + model.to_dict(), + raw, + "to_dict method output matches raw input", ) self.assertEqual(model.continent.geoname_id, 42, "continent geoname_id is 42") self.assertEqual(model.continent.code, "NA", "continent code is NA") self.assertEqual( - model.continent.names, {"en": "North America"}, "continent names is correct" + model.continent.names, + {"en": "North America"}, + "continent names is correct", ) self.assertEqual( - model.continent.name, "North America", "continent name is correct" + model.continent.name, + "North America", + "continent name is correct", ) self.assertEqual(model.country.geoname_id, 1, "country geoname_id is 1") self.assertEqual(model.country.iso_code, "US", "country iso_code is US") @@ -312,11 +353,15 @@ def test_city_full(self) -> None: "country names is correct", ) self.assertEqual( - model.country.name, "United States of America", "country name is correct" + model.country.name, + "United States of America", + "country name is correct", ) self.assertEqual(model.country.confidence, None, "country confidence is None") self.assertEqual( - model.registered_country.iso_code, "CA", "registered_country iso_code is CA" + model.registered_country.iso_code, + "CA", + "registered_country iso_code is CA", ) self.assertEqual( model.registered_country.names, @@ -346,7 +391,8 @@ def test_city_full(self) -> None: self.assertEqual(model.to_dict(), raw, "to_dict method matches raw input") self.assertRegex( - str(model), r"^geoip2.models.City\(\[.*en.*\], .*geoname_id.*\)" + str(model), + r"^geoip2.models.City\(\[.*en.*\], .*geoname_id.*\)", ) self.assertFalse(model == True, "__eq__ does not blow up on weird input") @@ -393,12 +439,14 @@ def test_unknown_keys(self) -> None: with self.assertRaises(AttributeError): model.traits.invalid # type: ignore self.assertEqual( - model.traits.ip_address, ipaddress.ip_address("1.2.3.4"), "correct ip" + model.traits.ip_address, + ipaddress.ip_address("1.2.3.4"), + "correct ip", ) class TestNames(unittest.TestCase): - raw: Dict = { + raw: dict = { "continent": { "code": "NA", "geoname_id": 42, @@ -460,16 +508,22 @@ def test_two_locales(self) -> None: def test_unknown_locale(self) -> None: model = geoip2.models.Country(locales=["aa"], **self.raw) self.assertEqual( - model.continent.name, None, "continent name is undef (no Afar available)" + model.continent.name, + None, + "continent name is undef (no Afar available)", ) self.assertEqual( - model.country.name, None, "country name is in None (no Afar available)" + model.country.name, + None, + "country name is in None (no Afar available)", ) def test_german(self) -> None: model = geoip2.models.Country(locales=["de"], **self.raw) self.assertEqual( - model.continent.name, "Nordamerika", "Correct german name for continent" + model.continent.name, + "Nordamerika", + "Correct german name for continent", ) diff --git a/tests/webservice_test.py b/tests/webservice_test.py index 0d6cc496..c17f86b8 100644 --- a/tests/webservice_test.py +++ b/tests/webservice_test.py @@ -1,17 +1,17 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- import asyncio import copy import ipaddress import sys -from typing import cast, Dict import unittest -from pytest_httpserver import HeaderValueMatcher -import pytest_httpserver -import pytest +from abc import ABC, abstractmethod from collections import defaultdict +from typing import cast, Callable, Union +import pytest +import pytest_httpserver +from pytest_httpserver import HeaderValueMatcher sys.path.append("..") import geoip2 @@ -27,7 +27,10 @@ from geoip2.webservice import AsyncClient, Client -class TestBaseClient(unittest.TestCase): +class TestBaseClient(unittest.TestCase, ABC): + client: Union[AsyncClient, Client] + client_class: Callable[[int, str], Union[AsyncClient, Client]] + country = { "continent": {"code": "NA", "geoname_id": 42, "names": {"en": "North America"}}, "country": { @@ -51,10 +54,13 @@ class TestBaseClient(unittest.TestCase): # this is not a comprehensive representation of the # JSON from the server - insights = cast(Dict, copy.deepcopy(country)) + insights = cast(dict, copy.deepcopy(country)) insights["traits"]["user_count"] = 2 insights["traits"]["static_ip_score"] = 1.3 + @abstractmethod + def run_client(self, v): ... + def _content_type(self, endpoint): return ( "application/vnd.maxmind.com-" @@ -63,12 +69,13 @@ def _content_type(self, endpoint): ) @pytest.fixture(autouse=True) - def setup_httpserver(self, httpserver: pytest_httpserver.HTTPServer): + def setup_httpserver(self, httpserver: pytest_httpserver.HTTPServer) -> None: self.httpserver = httpserver - def test_country_ok(self): + def test_country_ok(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/country/1.2.3.4", method="GET" + "/geoip/v2.1/country/1.2.3.4", + method="GET", ).respond_with_json( self.country, status=200, @@ -76,12 +83,16 @@ def test_country_ok(self): ) country = self.run_client(self.client.country("1.2.3.4")) self.assertEqual( - type(country), geoip2.models.Country, "return value of client.country" + type(country), + geoip2.models.Country, + "return value of client.country", ) self.assertEqual(country.continent.geoname_id, 42, "continent geoname_id is 42") self.assertEqual(country.continent.code, "NA", "continent code is NA") self.assertEqual( - country.continent.name, "North America", "continent name is North America" + country.continent.name, + "North America", + "continent name is North America", ) self.assertEqual(country.country.geoname_id, 1, "country geoname_id is 1") self.assertIs( @@ -91,7 +102,9 @@ def test_country_ok(self): ) self.assertEqual(country.country.iso_code, "US", "country iso_code is US") self.assertEqual( - country.country.names, {"en": "United States of America"}, "country names" + country.country.names, + {"en": "United States of America"}, + "country names", ) self.assertEqual( country.country.name, @@ -99,7 +112,9 @@ def test_country_ok(self): "country name is United States of America", ) self.assertEqual( - country.maxmind.queries_remaining, 11, "queries_remaining is 11" + country.maxmind.queries_remaining, + 11, + "queries_remaining is 11", ) self.assertIs( country.registered_country.is_in_european_union, @@ -107,14 +122,17 @@ def test_country_ok(self): "registered_country is_in_european_union is True", ) self.assertEqual( - country.traits.network, ipaddress.ip_network("1.2.3.0/24"), "network" + country.traits.network, + ipaddress.ip_network("1.2.3.0/24"), + "network", ) self.assertTrue(country.traits.is_anycast) self.assertEqual(country.to_dict(), self.country, "raw response is correct") - def test_me(self): + def test_me(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/country/me", method="GET" + "/geoip/v2.1/country/me", + method="GET", ).respond_with_json( self.country, status=200, @@ -122,7 +140,9 @@ def test_me(self): ) implicit_me = self.run_client(self.client.country()) self.assertEqual( - type(implicit_me), geoip2.models.Country, "country() returns Country object" + type(implicit_me), + geoip2.models.Country, + "country() returns Country object", ) explicit_me = self.run_client(self.client.country()) self.assertEqual( @@ -131,9 +151,10 @@ def test_me(self): "country('me') returns Country object", ) - def test_200_error(self): + def test_200_error(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/country/1.1.1.1", method="GET" + "/geoip/v2.1/country/1.1.1.1", + method="GET", ).respond_with_data( "", status=200, @@ -141,32 +162,37 @@ def test_200_error(self): ) with self.assertRaisesRegex( - GeoIP2Error, "could not decode the response as JSON" + GeoIP2Error, + "could not decode the response as JSON", ): self.run_client(self.client.country("1.1.1.1")) - def test_bad_ip_address(self): + def test_bad_ip_address(self) -> None: with self.assertRaisesRegex( - ValueError, "'1.2.3' does not appear to be an IPv4 or IPv6 address" + ValueError, + "'1.2.3' does not appear to be an IPv4 or IPv6 address", ): self.run_client(self.client.country("1.2.3")) - def test_no_body_error(self): + def test_no_body_error(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/country/1.2.3.7", method="GET" + "/geoip/v2.1/country/1.2.3.7", + method="GET", ).respond_with_data( "", status=400, content_type=self._content_type("country"), ) with self.assertRaisesRegex( - HTTPError, "Received a 400 error for .* with no body" + HTTPError, + "Received a 400 error for .* with no body", ): self.run_client(self.client.country("1.2.3.7")) - def test_weird_body_error(self): + def test_weird_body_error(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/country/1.2.3.8", method="GET" + "/geoip/v2.1/country/1.2.3.8", + method="GET", ).respond_with_json( {"wierd": 42}, status=400, @@ -179,22 +205,25 @@ def test_weird_body_error(self): ): self.run_client(self.client.country("1.2.3.8")) - def test_bad_body_error(self): + def test_bad_body_error(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/country/1.2.3.9", method="GET" + "/geoip/v2.1/country/1.2.3.9", + method="GET", ).respond_with_data( "bad body", status=400, content_type=self._content_type("country"), ) with self.assertRaisesRegex( - HTTPError, "it did not include the expected JSON body" + HTTPError, + "it did not include the expected JSON body", ): self.run_client(self.client.country("1.2.3.9")) - def test_500_error(self): + def test_500_error(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/country/1.2.3.10", method="GET" + "/geoip/v2.1/country/1.2.3.10", + method="GET", ).respond_with_data( "", status=500, @@ -203,80 +232,84 @@ def test_500_error(self): with self.assertRaisesRegex(HTTPError, r"Received a server error \(500\) for"): self.run_client(self.client.country("1.2.3.10")) - def test_300_error(self): + def test_300_error(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/country/1.2.3.11", method="GET" + "/geoip/v2.1/country/1.2.3.11", + method="GET", ).respond_with_data( "", status=300, content_type=self._content_type("country"), ) with self.assertRaisesRegex( - HTTPError, r"Received a very surprising HTTP status \(300\) for" + HTTPError, + r"Received a very surprising HTTP status \(300\) for", ): self.run_client(self.client.country("1.2.3.11")) - def test_ip_address_required(self): + def test_ip_address_required(self) -> None: self._test_error(400, "IP_ADDRESS_REQUIRED", InvalidRequestError) - def test_ip_address_not_found(self): + def test_ip_address_not_found(self) -> None: self._test_error(404, "IP_ADDRESS_NOT_FOUND", AddressNotFoundError) - def test_ip_address_reserved(self): + def test_ip_address_reserved(self) -> None: self._test_error(400, "IP_ADDRESS_RESERVED", AddressNotFoundError) - def test_permission_required(self): + def test_permission_required(self) -> None: self._test_error(403, "PERMISSION_REQUIRED", PermissionRequiredError) - def test_auth_invalid(self): + def test_auth_invalid(self) -> None: self._test_error(400, "AUTHORIZATION_INVALID", AuthenticationError) - def test_license_key_required(self): + def test_license_key_required(self) -> None: self._test_error(401, "LICENSE_KEY_REQUIRED", AuthenticationError) - def test_account_id_required(self): + def test_account_id_required(self) -> None: self._test_error(401, "ACCOUNT_ID_REQUIRED", AuthenticationError) - def test_user_id_required(self): + def test_user_id_required(self) -> None: self._test_error(401, "USER_ID_REQUIRED", AuthenticationError) - def test_account_id_unkown(self): + def test_account_id_unkown(self) -> None: self._test_error(401, "ACCOUNT_ID_UNKNOWN", AuthenticationError) - def test_user_id_unkown(self): + def test_user_id_unkown(self) -> None: self._test_error(401, "USER_ID_UNKNOWN", AuthenticationError) - def test_out_of_queries_error(self): + def test_out_of_queries_error(self) -> None: self._test_error(402, "OUT_OF_QUERIES", OutOfQueriesError) - def _test_error(self, status, error_code, error_class): + def _test_error(self, status, error_code, error_class) -> None: msg = "Some error message" body = {"error": msg, "code": error_code} self.httpserver.expect_request( - "/geoip/v2.1/country/1.2.3.18", method="GET" + "/geoip/v2.1/country/1.2.3.18", + method="GET", ).respond_with_json( body, status=status, content_type=self._content_type("country"), ) - with self.assertRaisesRegex(error_class, msg): + with pytest.raises(error_class, match=msg): self.run_client(self.client.country("1.2.3.18")) - def test_unknown_error(self): + def test_unknown_error(self) -> None: msg = "Unknown error type" ip = "1.2.3.19" body = {"error": msg, "code": "UNKNOWN_TYPE"} self.httpserver.expect_request( - "/geoip/v2.1/country/" + ip, method="GET" + "/geoip/v2.1/country/" + ip, + method="GET", ).respond_with_json( body, status=400, content_type=self._content_type("country"), ) - with self.assertRaisesRegex(InvalidRequestError, msg): + with pytest.raises(InvalidRequestError, match=msg): self.run_client(self.client.country(ip)) - def test_request(self): + def test_request(self) -> None: def user_agent_compare(actual: str, expected: str) -> bool: if actual is None: return False @@ -293,8 +326,8 @@ def user_agent_compare(actual: str, expected: str) -> bool: header_value_matcher=HeaderValueMatcher( defaultdict( lambda: HeaderValueMatcher.default_header_value_matcher, - {"User-Agent": user_agent_compare}, - ) + {"User-Agent": user_agent_compare}, # type: ignore[dict-item] + ), ), ).respond_with_json( self.country, @@ -303,9 +336,10 @@ def user_agent_compare(actual: str, expected: str) -> bool: ) self.run_client(self.client.country("1.2.3.4")) - def test_city_ok(self): + def test_city_ok(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/city/1.2.3.4", method="GET" + "/geoip/v2.1/city/1.2.3.4", + method="GET", ).respond_with_json( self.country, status=200, @@ -314,13 +348,16 @@ def test_city_ok(self): city = self.run_client(self.client.city("1.2.3.4")) self.assertEqual(type(city), geoip2.models.City, "return value of client.city") self.assertEqual( - city.traits.network, ipaddress.ip_network("1.2.3.0/24"), "network" + city.traits.network, + ipaddress.ip_network("1.2.3.0/24"), + "network", ) self.assertTrue(city.traits.is_anycast) - def test_insights_ok(self): + def test_insights_ok(self) -> None: self.httpserver.expect_request( - "/geoip/v2.1/insights/1.2.3.4", method="GET" + "/geoip/v2.1/insights/1.2.3.4", + method="GET", ).respond_with_json( self.insights, status=200, @@ -328,32 +365,39 @@ def test_insights_ok(self): ) insights = self.run_client(self.client.insights("1.2.3.4")) self.assertEqual( - type(insights), geoip2.models.Insights, "return value of client.insights" + type(insights), + geoip2.models.Insights, + "return value of client.insights", ) self.assertEqual( - insights.traits.network, ipaddress.ip_network("1.2.3.0/24"), "network" + insights.traits.network, + ipaddress.ip_network("1.2.3.0/24"), + "network", ) self.assertTrue(insights.traits.is_anycast) self.assertEqual(insights.traits.static_ip_score, 1.3, "static_ip_score is 1.3") self.assertEqual(insights.traits.user_count, 2, "user_count is 2") - def test_named_constructor_args(self): + def test_named_constructor_args(self) -> None: id = 47 key = "1234567890ab" - client = self.client_class(account_id=id, license_key=key) + client = self.client_class(id, key) self.assertEqual(client._account_id, str(id)) self.assertEqual(client._license_key, key) - def test_missing_constructor_args(self): + def test_missing_constructor_args(self) -> None: with self.assertRaises(TypeError): - self.client_class(license_key="1234567890ab") + + self.client_class(license_key="1234567890ab") # type: ignore[call-arg] with self.assertRaises(TypeError): - self.client_class("47") + self.client_class("47") # type: ignore class TestClient(TestBaseClient): - def setUp(self): + client: Client + + def setUp(self) -> None: self.client_class = Client self.client = Client(42, "abcdef123456") self.client._base_uri = self.httpserver.url_for("/geoip/v2.1") @@ -364,14 +408,16 @@ def run_client(self, v): class TestAsyncClient(TestBaseClient): - def setUp(self): + client: AsyncClient + + def setUp(self) -> None: self._loop = asyncio.new_event_loop() self.client_class = AsyncClient self.client = AsyncClient(42, "abcdef123456") self.client._base_uri = self.httpserver.url_for("/geoip/v2.1") self.maxDiff = 20_000 - def tearDown(self): + def tearDown(self) -> None: self._loop.run_until_complete(self.client.close()) self._loop.close()