diff --git a/components/polylith/interface/parser.py b/components/polylith/interface/parser.py index ab9a68e3..f9c1da3c 100644 --- a/components/polylith/interface/parser.py +++ b/components/polylith/interface/parser.py @@ -1,10 +1,13 @@ import ast from functools import lru_cache from pathlib import Path -from typing import Set, Union +from typing import FrozenSet, List, Set, Union from polylith.imports import SYMBOLS, extract_api, list_imports, parse_module +PACKAGE_INTERFACE = "__init__.py" +ALL_STATEMENT = "__all__" + def target_names(t: ast.AST) -> Set[str]: if isinstance(t, ast.Name): @@ -53,13 +56,25 @@ def extract_public_variables(path: Path) -> Set[str]: def is_the_all_statement(target: ast.expr) -> bool: - return isinstance(target, ast.Name) and target.id == "__all__" + return isinstance(target, ast.Name) and target.id == ALL_STATEMENT def is_string_constant(expression: ast.AST) -> bool: return isinstance(expression, ast.Constant) and isinstance(expression.value, str) +def attribute_expr_to_parts(expr: ast.AST) -> List[str]: + if isinstance(expr, ast.Name): + return [expr.id] + + if isinstance(expr, ast.Attribute): + parent = attribute_expr_to_parts(expr.value) + + return [*parent, expr.attr] if parent else [] + + return [] + + def find_the_all_variable(statement: ast.stmt) -> Union[Set[str], None]: if not isinstance(statement, ast.Assign): return None @@ -76,12 +91,69 @@ def find_the_all_variable(statement: ast.stmt) -> Union[Set[str], None]: return {e.value for e in statement.value.elts if isinstance(e, ast.Constant)} -def extract_the_all_variable(path: Path) -> Set[str]: +def find_the_all_pointer(statement: ast.stmt) -> Union[str, None]: + if not isinstance(statement, ast.Assign): + return None + + if not any(is_the_all_statement(t) for t in statement.targets): + return None + + parts = attribute_expr_to_parts(statement.value) + + if not parts: + return None + + *module_path, rest = parts + + if rest != ALL_STATEMENT: + return None + + return ".".join(module_path) + + +def resolve_local_module_path(package_dir: Path, module_ref: str) -> Union[Path, None]: + parts = tuple(p for p in module_ref.split(".") if p) + + if not parts: + return None + + module_file = package_dir.joinpath(*parts).with_suffix(".py") + + if module_file.exists(): + return module_file + + module_init = package_dir.joinpath(*parts, PACKAGE_INTERFACE) + + return module_init if module_init.exists() else None + + +def _extract_the_all_variable(path: Path, visited: FrozenSet[Path]) -> Set[str]: + if path in visited: + return set() + + visited = visited | frozenset({path}) + tree = parse(path) - res = [find_the_all_variable(s) for s in tree.body] + literals = [find_the_all_variable(s) for s in tree.body] + literal = next((r for r in literals if r is not None), None) + + if literal is not None: + return literal + + pointers = (find_the_all_pointer(s) for s in tree.body) + pointer = next((p for p in pointers if p is not None), None) + + if not pointer: + return set() - return next((r for r in res if r is not None), set()) + resolved = resolve_local_module_path(path.parent, pointer) + + return _extract_the_all_variable(resolved, visited) if resolved else set() + + +def extract_the_all_variable(path: Path) -> Set[str]: + return _extract_the_all_variable(path, frozenset()) def extract_imported_api(path: Path) -> Set[str]: @@ -98,7 +170,7 @@ def fetch_api_for_path(path: Path) -> Set[str]: def fetch_api(paths: Set[Path]) -> dict: - interface_paths = [Path(p / "__init__.py") for p in paths] + interface_paths = [Path(p / PACKAGE_INTERFACE) for p in paths] interfaces = [p for p in interface_paths if p.exists()] diff --git a/test/components/polylith/interface/test_parse_api.py b/test/components/polylith/interface/test_parse_api.py index 954e4733..8bdddb8d 100644 --- a/test/components/polylith/interface/test_parse_api.py +++ b/test/components/polylith/interface/test_parse_api.py @@ -72,6 +72,29 @@ def test_extract_the_all_variable(monkeypatch) -> None: assert res == {"thing", "other", "message"} +def test_extract_the_all_variable_from_module_pointer(tmp_path) -> None: + # Use real files to exercise module resolution: + # comp/__init__.py: __all__ = core.__all__ + # comp/core.py: __all__ = ["pub_func"] + parser.parse.cache_clear() + + package_dir = tmp_path / "comp" + package_dir.mkdir(parents=True) + + init = package_dir / "__init__.py" + core = package_dir / "core.py" + + init.write_text("from .core import *\n\n__all__ = core.__all__\n") + core.write_text( + "__all__ = [\"pub_func\"]\n\n\ndef pub_func():\n pass\n" + ) + + res = parser.extract_the_all_variable(init) + + assert res == {"pub_func"} + assert parser.fetch_api_for_path(init) == {"pub_func"} + + def test_fetch_api_for_path(monkeypatch) -> None: fn = partial(fake_parse, the_interface)