Skip to content
Draft

wip #422

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 78 additions & 6 deletions components/polylith/interface/parser.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import ast
from functools import lru_cache
from pathlib import Path
from typing import Set, Union
from typing import FrozenSet, List, Set, Union

from polylith.imports import SYMBOLS, extract_api, list_imports, parse_module

PACKAGE_INTERFACE = "__init__.py"
ALL_STATEMENT = "__all__"


def target_names(t: ast.AST) -> Set[str]:
if isinstance(t, ast.Name):
Expand Down Expand Up @@ -53,13 +56,25 @@ def extract_public_variables(path: Path) -> Set[str]:


def is_the_all_statement(target: ast.expr) -> bool:
return isinstance(target, ast.Name) and target.id == "__all__"
return isinstance(target, ast.Name) and target.id == ALL_STATEMENT


def is_string_constant(expression: ast.AST) -> bool:
return isinstance(expression, ast.Constant) and isinstance(expression.value, str)


def attribute_expr_to_parts(expr: ast.AST) -> List[str]:
if isinstance(expr, ast.Name):
return [expr.id]

if isinstance(expr, ast.Attribute):
parent = attribute_expr_to_parts(expr.value)

return [*parent, expr.attr] if parent else []

return []


def find_the_all_variable(statement: ast.stmt) -> Union[Set[str], None]:
if not isinstance(statement, ast.Assign):
return None
Expand All @@ -76,12 +91,69 @@ def find_the_all_variable(statement: ast.stmt) -> Union[Set[str], None]:
return {e.value for e in statement.value.elts if isinstance(e, ast.Constant)}


def extract_the_all_variable(path: Path) -> Set[str]:
def find_the_all_pointer(statement: ast.stmt) -> Union[str, None]:
if not isinstance(statement, ast.Assign):
return None

if not any(is_the_all_statement(t) for t in statement.targets):
return None

parts = attribute_expr_to_parts(statement.value)

if not parts:
return None

*module_path, rest = parts

if rest != ALL_STATEMENT:
return None

return ".".join(module_path)


def resolve_local_module_path(package_dir: Path, module_ref: str) -> Union[Path, None]:
parts = tuple(p for p in module_ref.split(".") if p)

if not parts:
return None

module_file = package_dir.joinpath(*parts).with_suffix(".py")

if module_file.exists():
return module_file

module_init = package_dir.joinpath(*parts, PACKAGE_INTERFACE)

return module_init if module_init.exists() else None


def _extract_the_all_variable(path: Path, visited: FrozenSet[Path]) -> Set[str]:
if path in visited:
return set()

visited = visited | frozenset({path})

tree = parse(path)

res = [find_the_all_variable(s) for s in tree.body]
literals = [find_the_all_variable(s) for s in tree.body]
literal = next((r for r in literals if r is not None), None)

if literal is not None:
return literal

pointers = (find_the_all_pointer(s) for s in tree.body)
pointer = next((p for p in pointers if p is not None), None)

if not pointer:
return set()

return next((r for r in res if r is not None), set())
resolved = resolve_local_module_path(path.parent, pointer)

return _extract_the_all_variable(resolved, visited) if resolved else set()


def extract_the_all_variable(path: Path) -> Set[str]:
return _extract_the_all_variable(path, frozenset())


def extract_imported_api(path: Path) -> Set[str]:
Expand All @@ -98,7 +170,7 @@ def fetch_api_for_path(path: Path) -> Set[str]:


def fetch_api(paths: Set[Path]) -> dict:
interface_paths = [Path(p / "__init__.py") for p in paths]
interface_paths = [Path(p / PACKAGE_INTERFACE) for p in paths]

interfaces = [p for p in interface_paths if p.exists()]

Expand Down
23 changes: 23 additions & 0 deletions test/components/polylith/interface/test_parse_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,29 @@ def test_extract_the_all_variable(monkeypatch) -> None:
assert res == {"thing", "other", "message"}


def test_extract_the_all_variable_from_module_pointer(tmp_path) -> None:
# Use real files to exercise module resolution:
# comp/__init__.py: __all__ = core.__all__
# comp/core.py: __all__ = ["pub_func"]
parser.parse.cache_clear()

package_dir = tmp_path / "comp"
package_dir.mkdir(parents=True)

init = package_dir / "__init__.py"
core = package_dir / "core.py"

init.write_text("from .core import *\n\n__all__ = core.__all__\n")
core.write_text(
"__all__ = [\"pub_func\"]\n\n\ndef pub_func():\n pass\n"
)

res = parser.extract_the_all_variable(init)

assert res == {"pub_func"}
assert parser.fetch_api_for_path(init) == {"pub_func"}


def test_fetch_api_for_path(monkeypatch) -> None:
fn = partial(fake_parse, the_interface)

Expand Down