diff --git a/docs/BaseComponent_ko.md b/docs/BaseComponent_ko.md new file mode 100644 index 0000000..c98635f --- /dev/null +++ b/docs/BaseComponent_ko.md @@ -0,0 +1,281 @@ +# BaseComponent + +`BaseComponent`는 **define-by-run(순수 파이썬 제어)** 철학을 유지하면서도, 컴포넌트 실행을 **관측 가능(observable)** 하게 만들기 위한 **선택적(opt-in) 표준 레이어**입니다. + +* 파이프라인은 `step(run: RunContext) -> RunContext` 형태의 **그냥 함수/콜러블**만으로도 충분히 동작합니다. +* `BaseComponent`는 그 위에 **추적(hooks), 에러 표준화, 이름/형식 통일**을 얹어주는 역할을 합니다. + +즉, **필수는 아니지만**, 라이브러리/팀 단위 개발에서 “운영 가능한 형태”로 만들고 싶을 때 유용합니다. + +--- + +## 왜 필요한가? + +### 1) 관측성(Tracing)을 “그래프 엔진 없이” 얻기 위해 + +Lang2SQL은 LangGraph 같은 그래프 엔진을 강제하지 않습니다. 대신: + +* 사용자는 Python `if/for/while`로 제어한다. +* 라이브러리는 관측성은 **hook 이벤트**로 제공한다. + +`BaseComponent`는 각 컴포넌트 실행의 `start/end/error`를 이벤트로 남깁니다. + +### 2) 에러를 “도메인 친화적으로” 정리하기 위해 + +현실에서는 `ValueError`, `KeyError`, 외부 라이브러리 예외 등이 섞여서 올라옵니다. + +`BaseComponent`는: + +* `Lang2SQLError`(ValidationError, IntegrationMissingError 등)는 **그대로 유지** +* 그 외 예외는 `ComponentError`로 **표준 래핑**(+ 원인 예외를 `cause`로 보존) + +→ 사용자/운영자 관점에서 “어디서 터졌는지”가 분명해집니다. + +### 3) “컴포넌트 단위 표준”을 만들기 위해 + +라이브러리 제공 컴포넌트를 모두 BaseComponent 기반으로 만들면: + +* 로그/트레이스의 포맷이 통일 +* 테스트/디버깅 경험이 일정 +* 문서/타입 힌트가 일관 + +--- + +## 철학: Define-by-run + Minimal core + +Lang2SQL의 기본 철학은 아래 2개입니다. + +1. **제어는 파이썬으로** + 루프/분기/재시도/서브플로우 호출은 “프레임워크 DSL”이 아니라 Python으로 표현합니다. + +2. **상태는 RunContext 하나로** + 파이프라인이 커져도, step 간 연결이 깨지지 않도록 `RunContext`를 I/O로 둡니다. + +`BaseComponent`는 이 철학을 해치지 않습니다. +컴포넌트의 실행을 감싸서 이벤트만 남길 뿐, 그래프/스키마/실행 모델을 강제하지 않습니다. + +--- + +## BaseComponent가 제공하는 API + +### 생성자 + +```python +BaseComponent(name: str | None = None, hook: TraceHook | None = None) +``` + +* `name`: 이벤트에 찍힐 컴포넌트 이름 (기본값: 클래스명) +* `hook`: 이벤트 수신자. 기본값은 `NullHook()` (아무것도 하지 않음) + +### 구현해야 하는 것: `run()` + +```python +class MyComp(BaseComponent): + def run(self, run: RunContext) -> RunContext: + ... + return run +``` + +### 실행: `__call__` + +`comp(run)`을 호출하면 내부적으로 아래를 자동 수행합니다. + +* `component.run start 이벤트 발행` +* `self.run(...)` 실행 +* 성공 시 `end 이벤트` + `duration_ms` +* 실패 시 `error 이벤트` + + * 도메인 예외(`Lang2SQLError`)는 그대로 raise + * 그 외 예외는 `ComponentError`로 래핑해서 raise + +--- + +## 권장 규약: RunContext in → RunContext out + +Lang2SQL의 기본 step 규약은 단순합니다. + +> **RunContext를 받으면 RunContext를 반환한다.** +> (`return run`을 습관처럼) + +왜냐하면 “None 반환”은 인간이 보기엔 자연스럽지만, 팀/사용자 관점에서는 실수를 만들기 쉽습니다. + +* `return None`은 “의도적”인지 “실수(반환 누락)”인지 구분이 안 됨 +* Flow/컴포넌트 조합에서 결과가 조용히 깨지기 쉬움 + +그래서 Lang2SQL은 **fail-fast** 스타일을 권장합니다. + +--- + +## 언제 BaseComponent를 쓰는가? + +### ✅ BaseComponent를 쓰는 게 좋은 경우 + +* 라이브러리 기본 제공 컴포넌트( retriever/builder/generator/validator ) +* 팀/제품 환경에서 **관측성(트레이싱)이 필요한 경우** +* 예외 표준화가 중요한 경우(운영/테스트/디버깅) + +### ✅ BaseComponent 없이 함수로 두는 게 좋은 경우 + +* `policy`, `eval`, metric 계산처럼 **순수 함수 성격**이 강한 로직 +* “유저가 빠르게 붙여 넣어 쓰는” 초경량 커스텀 로직 +* 실행 단위가 너무 작아 이벤트가 과도해지는 경우 + +즉, **핵심 파이프라인 축**은 BaseComponent로 잡고, +그 외의 작은 로직은 함수로 두는 혼합형이 가장 자연스럽습니다. + +--- + +## FunctionalComponent: “함수도 트레이싱하고 싶다” + +유저에게 “클래스 상속 + run 메서드 작성”이 부담인 경우가 많습니다. +그래서 **함수/콜러블을 그대로 유지하면서**도 트레이싱을 얻고 싶다면 래퍼를 제공합니다. + +### 예시: FunctionalComponent + +```python +from __future__ import annotations +from typing import Callable, Any, Optional + +from .base import BaseComponent +from .context import RunContext + +class FunctionalComponent(BaseComponent): + """ + Wrap a callable(run: RunContext) -> RunContext into a BaseComponent, + so it becomes traceable and error-normalized. + """ + + def __init__( + self, + fn: Callable[[RunContext], RunContext], + *, + name: str | None = None, + hook=None, + ) -> None: + super().__init__(name=name or getattr(fn, "__name__", "FunctionalComponent"), hook=hook) + self._fn = fn + + def run(self, run: RunContext) -> RunContext: + return self._fn(run) +``` + +### 사용 예 + +```python +def my_retriever(run: RunContext) -> RunContext: + run.schema_selected = ... + return run + +retriever = FunctionalComponent(my_retriever, name="MyRetriever", hook=hook) +``` + +> 이 방식의 장점: 유저는 “함수 스타일” 그대로 유지하면서, 운영/디버깅을 위한 트레이싱을 얻게 됩니다. + +--- + +## 훅(Tracing) 시스템이 뭐고, 왜 필요한가? + +### Hook이란? + +컴포넌트/플로우 실행 시점에 **이벤트(Event)** 를 받는 인터페이스입니다. + +* `start/end/error` 시점 기록 +* 소요 시간(duration_ms) +* 입력/출력 요약(input_summary/output_summary) +* 필요하면 `data`에 구조화된 값을 추가 + +### 어디서 확인하나? + +가장 쉬운 건 `MemoryHook`입니다. + +```python +from lang2sql.core.hooks import MemoryHook +hook = MemoryHook() + +flow = BaselineFlow(steps=[...], hook=hook) # 또는 컴포넌트마다 hook 주입 +out = flow.run_query("지난달 매출") + +# 이벤트 확인 +for e in hook.events: + print(e.phase, e.component, e.duration_ms, e.error) +``` + +### 운영용 관측성은 어디서 제어하나? + +운영에서는 `MemoryHook` 대신 다음이 일반적입니다. + +* 로그로 흘리는 Hook (stdout / JSON log) +* APM/Tracing으로 보내는 Hook (OpenTelemetry span 등) +* 필터링 Hook (특정 컴포넌트만 샘플링) + +핵심은: **관측성은 hook 구현체에서 제어**하고, 파이프라인/컴포넌트 코드는 최대한 “비즈니스 로직”만 갖도록 분리합니다. + +--- + +## 중첩(서브플로우/래핑)하면 트레이싱이 깨지나? + +“깨진다”기보다는 **이벤트가 더 많이 찍힙니다.** + +* `flow_b` 안에 `flow_a`를 step으로 넣으면 + + * `flow_b` 이벤트 2개(시작/끝) + * `flow_a` 이벤트 2개(시작/끝) + * `a1/a2` 컴포넌트 이벤트도 각각 찍힘(컴포넌트가 BaseComponent라면) + +이게 싫다면 두 가지 선택지가 있습니다. + +1. **상위 레벨(Flow)만 트레이싱하고 내부는 함수로 둔다** +2. **Hook에서 필터링/샘플링한다** (예: component 이름 prefix로 제외) + +추가 문법 없이 해결하려면 2번이 가장 현실적입니다. + +--- + +## 베스트 프랙티스 + +### 1) 구성(config)은 `__init__`에, 요청별 상태는 `RunContext`에 + +```python +class Retriever(BaseComponent): + def __init__(self, catalog, top_k=8, ...): + self.catalog = catalog # 고정 설정 + self.top_k = top_k + + def run(self, run: RunContext) -> RunContext: + # 요청마다 달라지는 값은 run에서 읽고 run에 쓴다 + ... + return run +``` + +### 2) RunContext가 들어오면 무조건 `return run` + +* 가독성(계약이 분명) +* 실수 방지(fail-fast) +* flow 합성 시 안정 + +### 3) “작은 로직(policy/eval)은 그냥 함수” + +* BaseComponent로 감싸는 건 선택 +* 운영에서 꼭 추적이 필요할 때만 FunctionalComponent로 감싼다 + +--- + +## FAQ + +### Q. “그냥 함수만 써도 되는데 왜 굳이 BaseComponent?” + +A. **운영/디버깅/협업에서** 차이가 큽니다. +문제 났을 때 “어디서, 어떤 입력으로, 얼마나 걸리다, 어떤 에러로” 터졌는지 자동으로 남는 게 핵심 가치입니다. + +### Q. “BaseComponent를 유저가 직접 써야 하나?” + +A. 필수 아닙니다. +초급 유저는 **SequentialFlow + 프리셋 컴포넌트**만으로 충분히 쓰게 하고, +고급/운영 유저에게 BaseComponent/Hook을 제공하는 구성이 가장 자연스럽습니다. + +### Q. “policy는 RunContext를 몰라도 되는데?” + +A. 맞습니다. `policy(metrics) -> action` 같은 건 순수 함수로 두는 걸 권장합니다. +필요하면 `FunctionalComponent(policy_fn)`처럼 감싸서 추적만 추가할 수 있습니다. + +--- diff --git a/docs/BaseFlow_ko.md b/docs/BaseFlow_ko.md new file mode 100644 index 0000000..479e5bc --- /dev/null +++ b/docs/BaseFlow_ko.md @@ -0,0 +1,183 @@ +# BaseFlow + +`BaseFlow`는 Lang2SQL에서 **define-by-run(순수 파이썬 제어)** 철학을 구현하기 위한 “플로우의 최소 추상화(minimal abstraction)”입니다. + +* 파이프라인의 **제어권(control-flow)** 을 프레임워크 DSL이 아니라 **사용자 코드(Python)** 가 갖습니다. +* LangGraph 같은 그래프 엔진을 강제하지 않습니다. +* 대신, 실행 단위를 `Flow`로 묶고 **관측성(hooks)** 과 **에러 규약**을 통일합니다. + +--- + +## 왜 필요한가? + +### 1) “제어는 파이썬으로”를 지키기 위해 + +Text2SQL은 현실적으로 다음 제어가 자주 필요합니다. + +* 재시도 루프 (`while`, `for`) +* 조건 분기 (`if`, `match`) +* 부분 파이프라인(서브플로우) 호출 +* 정책(policy) 기반 행동 결정 + +`BaseFlow`는 이런 제어를 **사용자가 Python으로 직접 작성**하게 두고, 라이브러리는 “실행 컨테이너 + 관측성”만 제공합니다. + +### 2) 요청 단위 관측성(Flow-level tracing) + +운영/디버깅에서는 “이 요청 전체가 언제 시작했고, 어디서 실패했고, 얼마나 걸렸는지”가 먼저 중요합니다. + +`BaseFlow`는 다음 이벤트를 발행합니다. + +* `flow.run` start / end / error +* 실행 시간(`duration_ms`) + +→ 요청 1건을 **Flow 단위로 빠르게 파악**할 수 있습니다. + +### 3) 공통 엔트리포인트(run_query) 제공 + +Text2SQL은 대부분 “문장(query)”이 시작점입니다. + +`run_query("...")`를 제공하면: + +* 초급 사용자는 `RunContext`를 몰라도 “바로 실행” 가능 +* 고급 사용자는 `run(RunContext)`로 제어를 확장 가능 + +--- + +## BaseFlow가 제공하는 API + +### 1) 구현해야 하는 것: `run()` + +```python +class MyFlow(BaseFlow): + def run(self, run: RunContext) -> RunContext: + ... + return run +``` + +* Flow의 본체 로직은 여기에 작성합니다. +* 제어는 Python으로 직접 작성합니다. (`if/for/while`) + +### 2) 호출: `__call__` + +```python +out = flow(run) +``` + +* 내부적으로 `flow.run(...)`을 호출합니다. +* hook 이벤트를 `start/end/error`로 기록합니다. + +### 3) 편의 엔트리포인트: `run_query()` + +```python +out = flow.run_query("지난달 매출") +``` + +* 내부에서 `RunContext(query=...)`를 만들고 `run()`을 호출합니다. +* Quickstart / demo / 초급 UX용 엔트리포인트입니다. + +> 권장: **BaseFlow에 run_query를 둬서 “모든 Flow는 run_query가 된다”는 직관을 유지**합니다. + +--- + +## run(runcontext) vs run_query(query) + +둘은 기능적으로 **같은 동작**을 하도록 설계합니다. + +```python +out1 = flow.run_query("지난달 매출") +out2 = flow.run(RunContext(query="지난달 매출")) +``` + +* `run_query(query)` : 문자열 query에서 시작하는 편의 API +* `run(runcontext)` : 고급 사용자를 위한 명시적 API + +--- + +## 사용 패턴 + +### 1) 초급: SequentialFlow로 구성하고 run_query로 실행 + +초급 사용자는 보통 “구성만 하고 실행”하면 됩니다. + +```python +flow = SequentialFlow(steps=[retriever, builder, generator, validator]) +out = flow.run_query("지난달 매출") +``` + +### 2) 고급: CustomFlow로 제어(while/if/policy) + +정책/루프/재시도 같은 제어가 들어오면 `BaseFlow`를 직접 상속해 작성하는 것이 가장 깔끔합니다. + +```python +class RetryFlow(BaseFlow): + def run(self, run: RunContext) -> RunContext: + while True: + run = retriever(run) + metrics = eval_retrieval(run) # 순수 함수 가능 + action = policy(metrics) # 순수 함수 가능 + if action == "retry": + continue + break + + run = generator(run) + run = validator(run) + return run +``` + +### 3) Sequential을 유지하면서 동적 파라미터가 필요하면 closure/partial + +이건 “필수”가 아니라, **steps 배열을 유지하고 싶은 사람을 위한 옵션**입니다. + +--- + +## Hook(Tracing)은 어디서 확인하나? + +Flow도 hook을 받을 수 있습니다. + +```python +from lang2sql.core.hooks import MemoryHook + +hook = MemoryHook() +flow = SequentialFlow(steps=[...], hook=hook) + +out = flow.run_query("지난달 매출") + +for e in hook.events: + print(e.name, e.phase, e.component, e.duration_ms, e.error) +``` + +운영에서는 `MemoryHook` 대신 로그/OTel/필터링 훅을 사용합니다. +관측성 제어는 **hook 구현체에서** 담당하고, Flow 코드는 비즈니스 로직에 집중하도록 분리합니다. + +--- + +## (관련 개념) BaseFlow와 BaseComponent의 관계 + +* `BaseFlow`는 “어떻게 실행할지(제어/조립)”를 담당합니다. +* `BaseComponent`는 “한 단계에서 무엇을 할지(작업 단위)”를 담당합니다. + +일반적으로: + +* **Flow는 여러 Component를 호출**하거나, +* **SequentialFlow는 Component/함수를 steps로 받아 순차 실행**합니다. + +즉, **Flow가 상위 레벨 오케스트레이션**, Component가 **재사용 가능한 부품**입니다. + +--- + +## FAQ + +### Q. BaseFlow가 필수인가? + +A. Flow라는 개념은 사실상 필요하지만, **모든 사용자가 BaseFlow를 직접 상속할 필요는 없습니다.** + +* 초급: `SequentialFlow`만 사용 +* 고급: `BaseFlow`를 상속해서 제어를 직접 작성 + +### Q. Flow의 반환 타입은? + +A. `run()`은 **반드시 `RunContext`를 반환**하는 것을 권장합니다. +(합성/디버깅/타입 안정성 측면에서 이득이 큽니다.) + +--- + diff --git a/docs/Core_concept_ko.md b/docs/Core_concept_ko.md new file mode 100644 index 0000000..676dba7 --- /dev/null +++ b/docs/Core_concept_ko.md @@ -0,0 +1,137 @@ +# Core Concepts + +Lang2SQL은 “그래프 엔진/DSL”을 강제하지 않고, **순수 Python 코드로 파이프라인을 제어**하는 define-by-run 철학을 따릅니다. +대신, 파이프라인이 커져도 연결이 무너지지 않도록 **`RunContext`라는 최소 상태 컨테이너**를 중심으로 설계합니다. + +--- + +## 1) Define-by-run: 제어는 Python으로 + +Lang2SQL에서 파이프라인 제어는 프레임워크가 아니라 **사용자 코드가 가집니다.** + +* 분기: `if / match` +* 반복/재시도: `for / while` +* 조건부 실행: policy 기반 action +* 서브플로우: flow를 step처럼 호출 + +예시: + +```python +def ret(run): ... +def ret_val(run): ... +def policy(metrics): ... +def gen(run): ... + +run = RunContext("q") + +while True: + run = ret(run) + metrics = ret_val(run) # ✅ run 몰라도 되는 순수 함수 가능 + action = policy(metrics) # ✅ run 몰라도 되는 순수 함수 가능 + if action == "retry": + continue + break + +run = gen(run) +``` + +**핵심:** Lang2SQL은 위 패턴을 “프레임워크 문법”으로 바꾸지 않습니다. +그냥 Python으로 쓰되, 파이프라인 간 상태 전달을 안정적으로 하기 위해 `RunContext`를 사용합니다. + +--- + +## 2) 왜 RunContext가 필요한가? + +Text2SQL 파이프라인은 현실적으로 단계가 늘어납니다. + +* retriever 1개가 아니라 10개, 100개가 될 수 있음 +* 중간 산출물(선택된 테이블, 컨텍스트, 후보 SQL, 검증 결과, 점수/메트릭)이 늘어남 +* loop/branch가 들어가면서 “어떤 단계에서 무엇이 생성되었는지” 추적이 어려워짐 + +이 상황에서 단계마다 함수 시그니처를 계속 바꾸면: + +* `retriever(query, catalog) -> selected` +* `builder(query, selected) -> context` +* `generator(query, context) -> sql` +* `validator(sql) -> validation` + +처럼 보이지만, 실제로는 **중간에 필요한 값이 계속 추가**되어 시그니처가 폭발합니다. + +### RunContext는 “큰 그래프에서 연결 안정성”을 만든다 + +Lang2SQL은 각 step의 I/O를 **`RunContext -> RunContext`**로 고정합니다. + +* step이 늘어나도 “연결 방식”이 바뀌지 않음 +* 어떤 단계가 어떤 값을 추가해도, 다음 단계는 필요한 값을 `run`에서 읽으면 됨 +* loop/branch/서브플로우에서도 동일한 규약 유지 + +그래서 문서에서 아래처럼 “개념적 함수형”으로 설명하더라도: + +* retriever: (query, catalog) -> selected +* builder: (query, selected) -> context +* generator: (query, context) -> sql +* validator: (sql) -> validation + +실제 구현은 **RunContext 내부 필드의 Read/Write 규약**으로 통일됩니다. + +예: + +* retriever: `run.query`, `run.schema_catalog` 읽고 → `run.schema_selected` 씀 +* builder: `run.query`, `run.schema_selected` 읽고 → `run.schema_context` 씀 +* generator: `run.query`, `run.schema_context` 읽고 → `run.sql` 씀 +* validator: `run.sql` 읽고 → `run.validation` 씀 + +### “쿼리가 바뀌면?”도 제어 가능 + +`RunContext`는 mutable state carrier이므로, 루프 중간에 쿼리를 업데이트해도 됩니다. + +```python +run.query = rewritten_query +run = ret(run) # 업데이트된 query로 재검색 +``` + +--- + +## 3) `run(runcontext)` vs `run_query(query)` + +두 API의 관계는 단순합니다. + +### `run(run: RunContext) -> RunContext` + +* **명시적 엔트리포인트** +* 고급 제어(루프/분기/정책)나 서브플로우 합성에서 자연스럽습니다. + +```python +run = RunContext(query="지난달 매출") +out = flow.run(run) +``` + +### `run_query(query: str) -> RunContext` + +* **편의(sugar) 엔트리포인트** +* 초급/데모/퀵스타트에서 `RunContext`를 몰라도 실행 가능하게 합니다. +* 내부적으로는 보통 아래와 동치입니다: + +```python +out = flow.run(RunContext(query=query)) +``` + +즉, + +```python +out1 = flow.run_query("지난달 매출") +out2 = flow.run(RunContext(query="지난달 매출")) +``` + +은 **같은 기능**을 제공합니다. 차이는 **입력 형태(문자열 vs RunContext)** 뿐입니다. + +--- + +## 권장 규약 요약 + +* **제어는 Python으로 한다** (define-by-run) +* **상태 전달은 RunContext로 고정한다** (`RunContext -> RunContext`) +* `run_query()`는 **초급/데모용 편의 API**, `run()`은 **명시적/고급 제어용 API** +* policy/eval처럼 RunContext가 필요 없는 로직은 **순수 함수로 둬도 된다** (필요하면 run에서 읽거나 metadata로 남기는 건 선택) + +--- diff --git a/docs/Hook_and_exception_ko.md b/docs/Hook_and_exception_ko.md new file mode 100644 index 0000000..c5764e5 --- /dev/null +++ b/docs/Hook_and_exception_ko.md @@ -0,0 +1,311 @@ +# Hooks (Tracing) + +Lang2SQL의 hooks 시스템은 **그래프 엔진 없이도 관측성(observability)을 제공**하기 위한 최소 레이어입니다. +Flow/Component 실행 과정에서 이벤트를 발행하고, 사용자는 hook 구현체로 이를 수집/출력/전송할 수 있습니다. + +핵심 컨셉은 단 하나입니다: + +> **“실행 중 무슨 일이 일어났는지(Event)를 hook이 받는다.”** + +--- + +## Event + +`Event`는 Flow/Component 실행 중 발생한 “관측 단위”입니다. + +```py +@dataclass +class Event: + name: str # e.g., "component.run" / "flow.run" + component: str # e.g., "KeywordTableRetriever" / "SequentialFlow" + phase: Literal["start", "end", "error"] + ts: float # unix timestamp + duration_ms: Optional[float] = None + + input_summary: Optional[str] = None + output_summary: Optional[str] = None + error: Optional[str] = None + + data: dict[str, Any] = field(default_factory=dict) +``` + +### 필드 의미 + +* `name` + + * 이벤트 종류를 나타내는 문자열 + * 예: `"component.run"`, `"flow.run"` +* `component` + + * 이벤트를 발생시킨 실행 단위 이름 + * 예: `"KeywordTableRetriever"`, `"SequentialFlow"` +* `phase` + + * `"start" | "end" | "error"` +* `ts` + + * 이벤트 발생 시간(Unix timestamp) +* `duration_ms` + + * `end/error`에서만 주로 채움(실행 시간) +* `input_summary`, `output_summary` + + * 디버깅을 위한 “사람이 읽기 쉬운” 요약 문자열 +* `error` + + * 실패 시 오류 요약 문자열 +* `data` + + * UI/필터링/테스트/추가 메타를 위한 구조화 payload + * 기본은 빈 dict이며, 필요할 때만 채우는 것을 권장합니다. + +--- + +## TraceHook + +`TraceHook`은 이벤트를 받는 인터페이스입니다. + +```py +class TraceHook(Protocol): + def on_event(self, event: Event) -> None: ... +``` + +* Lang2SQL의 Flow/Component는 실행 시점에 `hook.on_event(Event(...))` 형태로 이벤트를 발행합니다. +* hook은 **옵션**이며, 없으면 `NullHook`이 사용됩니다. + +--- + +## 기본 Hook 구현체 + +### NullHook + +```py +class NullHook: + def on_event(self, event: Event) -> None: + return +``` + +* 기본값 +* 아무 것도 하지 않습니다. +* hook 비용을 없애고 싶을 때 항상 안전한 기본 구현입니다. + +### MemoryHook + +```py +class MemoryHook: + def __init__(self) -> None: + self.events: list[Event] = [] + + def on_event(self, event: Event) -> None: + self.events.append(event) + + def clear(self) -> None: + self.events.clear() + + def snapshot(self) -> list[Event]: + return list(self.events) +``` + +* 이벤트를 메모리에 누적합니다. +* 테스트/디버깅에 가장 유용합니다. + +#### MemoryHook 사용 예시 + +```py +from lang2sql.core.hooks import MemoryHook +from lang2sql.flows.baseline import BaselineFlow + +hook = MemoryHook() +flow = BaselineFlow(steps=[...], hook=hook) + +out = flow.run_query("지난달 매출") + +for e in hook.events: + print(e.name, e.phase, e.component, e.duration_ms, e.error) +``` + +#### clear()를 유저가 직접 호출해야 하나? + +* 보통은 **테스트에서만** `clear()`가 필요합니다. (케이스 간 이벤트 섞임 방지) +* 일반 사용자는 보통 “요청 1회 → hook 1개 생성” 패턴으로 충분합니다. + +예: + +```py +hook = MemoryHook() +out = flow.run_query("q") # 여기서만 쓰고 끝 +events = hook.snapshot() +``` + +--- + +## 유틸 함수 + +### now() + +```py +def now() -> float: + return time.time() +``` + +* timestamp 생성에 사용됩니다. + +### ms() + +```py +def ms(start: float, end: float) -> float: + return (end - start) * 1000.0 +``` + +* duration(ms) 계산에 사용됩니다. + +### summarize() + +```py +def summarize(x: Any, max_len: int = 240) -> str: + ... +``` + +* repr(x)를 기반으로 요약 문자열을 만들고 길이를 제한합니다. +* 이벤트의 `input_summary/output_summary`에 사용됩니다. + +--- + +## 운영(Production)에서는 어떻게 쓰나? + +MemoryHook은 테스트용입니다. 운영에서는 보통 다음 형태로 확장합니다. + +* `LoggingHook`: JSON 로그로 남기기 +* `OTelHook`: OpenTelemetry span으로 전송 +* `FilteringHook`: 특정 component만 샘플링/필터링 + +관측성 제어는 **hook 구현체에서** 하고, Flow/Component 로직은 비즈니스에 집중하는 것이 기본 철학입니다. + +--- + +# Exceptions + +Lang2SQL 예외 시스템은 두 목표를 가집니다. + +1. **도메인 에러는 도메인 타입으로 유지**한다. +2. 외부/일반 예외는 “어디서 터졌는지”가 보이도록 **표준 래핑**한다. + +--- + +## Lang2SQLError (Base) + +```py +class Lang2SQLError(Exception): + """Base error for lang2sql.""" +``` + +* Lang2SQL에서 발생하는 모든 도메인 예외의 베이스입니다. +* `BaseComponent` / `BaseFlow`는 일반적으로 **Lang2SQLError는 그대로 다시 raise**합니다. + +--- + +## IntegrationMissingError + +```py +class IntegrationMissingError(Lang2SQLError): + def __init__(self, integration: str, extra: str | None = None, hint: str | None = None): + ... +``` + +### 언제 발생? + +* 선택적 의존성(optional integration)이 필요한데 설치되어 있지 않을 때 + +예: + +* `faiss` retriever를 쓰는데 `faiss`가 설치되어 있지 않음 + +### 메시지 특징 + +* `extra`가 있으면 설치 힌트를 포함합니다. + +예 메시지: + +* `Missing optional integration: faiss. Install with: pip install 'lang2sql[faiss]'` + +--- + +## ValidationError + +```py +class ValidationError(Lang2SQLError): + pass +``` + +### 언제 발생? + +* SQL 검증 실패, 정책상 금지 쿼리, 스키마 불일치 등 +* “유저 입력/생성 결과가 유효하지 않다”에 해당하는 에러를 담는 대표 도메인 예외 + +--- + +## ContractError + +```py +class ContractError(Lang2SQLError): + """Raised when a component violates a required call/return contract.""" + pass +``` + +### 언제 발생? + +* Lang2SQL이 요구하는 호출/반환 계약을 위반했을 때 +* 예: `RunContext -> RunContext` 계약인데 `None` 또는 `int`를 반환 + +이 에러는 “사용자 코드 버그를 빨리 발견(fail-fast)”하기 위한 타입입니다. + +--- + +## ComponentError + +```py +class ComponentError(Lang2SQLError): + def __init__(self, component: str, message: str, *, cause: Exception | None = None): + self.component = component + self.cause = cause + super().__init__(f"[{component}] {message}") +``` + +### 목적 + +* “일반 예외(ValueError, KeyError 등)”를 도메인 레이어로 끌어올 때 사용합니다. +* 어떤 컴포넌트에서 터졌는지 식별 가능하게 만듭니다. + +### cause + +* 원본 예외를 보존합니다. +* 테스트/디버깅에서 error chain을 확인할 수 있습니다. + +--- + +## 예외가 Flow/Component에서 어떻게 처리되나? + +(현재 BaseComponent 설계 기준) + +* `Lang2SQLError` 계열 + + * 그대로 이벤트에 기록하고 그대로 raise +* 그 외 모든 예외 + + * 이벤트에 기록하고 `ComponentError(..., cause=e)`로 래핑하여 raise + +즉: + +* **도메인 예외는 “정상적인 실패”로 취급** +* **일반 예외는 “버그/예상 밖 실패”로 표준화** + +--- + +## 권장 사용 가이드 + +* “사용자 입력/정책/검증 실패”는 `ValidationError` +* “의존성 설치 문제”는 `IntegrationMissingError` +* “계약 위반(반환 타입/호출 규약)”은 `ContractError` +* “외부 라이브러리/예상 밖 예외”는 `ComponentError`로 래핑되어 올라오는 것을 기본으로 합니다. + +--- diff --git a/docs/RunContext_ko.md b/docs/RunContext_ko.md new file mode 100644 index 0000000..3317216 --- /dev/null +++ b/docs/RunContext_ko.md @@ -0,0 +1,113 @@ +## RunContext + +`RunContext`는 define-by-run 파이프라인에서 **상태(state)를 운반하는 최소 State Carrier**입니다. +컴포넌트는 기본적으로 `RunContext -> RunContext` 계약을 따르며, 필요한 값을 읽고/쓰면서 파이프라인을 구성합니다. + +### 설계 원칙 + +* **최소 루트 필드 5개만 고정**: `inputs / artifacts / outputs / error / metadata` +* 루트는 전부 `dict` 기반(스키마 락인 방지) +* 자주 쓰는 값은 **alias 프로퍼티**로 제공하여 UX 개선 (`run.query`, `run.sql` 등) + +--- + +## 데이터 구조 트리 + +아래는 `RunContext`가 담는 데이터 구조를 “트리 형태”로 나타낸 것입니다. + +``` +RunContext +├─ inputs: dict +│ └─ "query": str +│ +├─ artifacts: dict +│ └─ "schema": dict +│ ├─ "catalog": Any +│ │ (예: list[TableSchema] | provider | None) +│ ├─ "selected": Any +│ │ (예: list[TableCandidate] | None) +│ └─ "context": str +│ (prompt에 넣을 스키마 컨텍스트) +│ +├─ outputs: dict +│ ├─ "sql": str +│ └─ "validation": Any +│ +├─ error: dict | None +│ └─ (구조화된 에러 정보. 형식은 프로젝트 정책에 따라 확장 가능) +│ +└─ metadata: dict + └─ (로그/추적/히스토리/실험용 값. 표준 스키마 강제 없음) + 예) + ├─ "events": list[Event] + ├─ "sql_drafts": list[str] + ├─ "attempt": int + └─ ... +``` + +--- + +## Root fields (고정 5개) + +* `inputs: dict[str, Any]` — 사용자 입력 +* `artifacts: dict[str, Any]` — 중간 산출물 +* `outputs: dict[str, Any]` — 최종 산출물 +* `error: Optional[dict[str, Any]]` — 구조화된 에러(선택) +* `metadata: dict[str, Any]` — 로그/추적/히스토리(선택) + +--- + +## 권장 키 컨벤션 (Minimal Standard) + +### inputs + +* `inputs["query"]`: 자연어 질의 + +### artifacts["schema"] + +* `catalog`: 스키마 카탈로그(테이블/컬럼 목록 등) +* `selected`: 선택된 테이블 후보 +* `context`: 프롬프트에 들어갈 스키마 컨텍스트 문자열 + +### outputs + +* `outputs["sql"]`: 최종 SQL +* `outputs["validation"]`: 검증 결과(구조는 구현체 자유) + +--- + +## Alias (Beginner-friendly API) + +키 문자열 접근을 줄이기 위해 alias를 제공합니다. + +* `run.query` ↔ `inputs["query"]` +* `run.sql` ↔ `outputs["sql"]` +* `run.validation` ↔ `outputs["validation"]` + +스키마 관련 alias: + +* `run.schema` ↔ `artifacts["schema"]` *(항상 dict로 보정)* +* `run.schema_catalog` ↔ `run.schema["catalog"]` +* `run.schema_selected` ↔ `run.schema["selected"]` +* `run.schema_context` ↔ `run.schema["context"]` + +--- + +## 파이프라인 예시 (Text2SQL) + +개념: + +* retriever: `(query, catalog) -> selected` +* builder: `(query, selected) -> context` +* generator: `(query, context) -> sql` +* validator: `(sql) -> validation` + +RunContext에서의 읽기/쓰기: + +* retriever: `run.query`, `run.schema_catalog` 읽고 → `run.schema_selected` 작성 +* builder: `run.query`, `run.schema_selected` 읽고 → `run.schema_context` 작성 +* generator: `run.query`, `run.schema_context` 읽고 → `run.sql` 작성 +* validator: `run.sql` 읽고 → `run.validation` 작성 + +--- + diff --git a/guideline.md b/guideline.md new file mode 100644 index 0000000..ad6fa31 --- /dev/null +++ b/guideline.md @@ -0,0 +1,99 @@ +lang2sql/ +├── __init__.py # public re-exports (Flow, Component, flows, components) +├── _version.py +│ +├── core/ # ✅ 외부 의존성 0% (절대 import 금지) +│ ├── __init__.py +│ ├── base.py # BaseComponent, BaseFlow (define-by-run 핵심 뼈대) +│ ├── types.py # 최소 타입(dataclass/typing): Table/Column/Result 등(강제 X, 참고용) +│ ├── exceptions.py # Lang2SQLError, ComponentError, IntegrationMissingError... +│ ├── hooks.py # TraceHook, Event, NullHook (관측/로깅 인터페이스) +│ ├── context.py # (선택) RunContext: dict wrapper (권장일 뿐 강제 X) +│ ├── registry.py # (선택) PluginRegistry + entry_points 로더 (retriever 100개 대비) +│ └── utils.py # 순수 유틸 +│ +├── components/ # ✅ Lang2SQL이 제공하는 “부품 상자” +│ ├── __init__.py +│ ├── retrieval/ +│ │ ├── __init__.py +│ │ ├── keyword.py # vstore 없는 기본 retriever (즉시 사용 가능) +│ │ ├── vector.py # vector retriever (실제 store는 integrations를 통해 주입) +│ │ ├── catalog.py # schema/catalog 기반 retriever +│ │ └── normalize.py # (선택) 후보 정규화 유틸 (표준 강제가 아니라 “편의”) +│ │ +│ ├── context/ +│ │ ├── __init__.py +│ │ ├── builder.py # build_context (토큰 예산/압축 정책 포함 가능) +│ │ └── budget.py # token budget helpers (의존성 없는 버전) +│ │ +│ ├── generation/ +│ │ ├── __init__.py +│ │ ├── sql.py # SQL 생성 컴포넌트(LLM은 integrations llm client를 주입) +│ │ └── prompts.py # 프롬프트 템플릿/formatters +│ │ +│ ├── validation/ +│ │ ├── __init__.py +│ │ ├── static.py # 문법/금지쿼리/스키마 참조 등 “실행 없는” 검증 +│ │ └── execution.py # (옵션) 실제 DB 실행 검증(연동은 integrations) +│ │ +│ └── adapters.py # ⭐ 핵심: 외부 retriever/llm 객체를 callable로 감싸는 as_callable +│ +├── flows/ # ✅ “완제품/프리셋” (define-by-run 클래스/함수) +│ ├── __init__.py +│ ├── baseline.py # BaselineFlow: retrieve -> context -> generate -> validate +│ ├── agentic.py # AgenticFlow: while/for 루프 포함 (사용자 override 용이) +│ └── examples.py # weird flow 예시(문서/테스트 겸) +│ +├── integrations/ # ✅ 외부 의존성 구현 (extras) +│ ├── __init__.py +│ ├── llm/ +│ │ ├── __init__.py +│ │ ├── base.py # thin wrapper 인터페이스(여긴 외부 의존성 있을 수 있음) +│ │ ├── openai_.py # openai extra +│ │ ├── anthropic_.py # anthropic extra +│ │ └── upstage_.py # upstage extra +│ │ +│ ├── vector_store/ +│ │ ├── __init__.py +│ │ ├── base.py # VectorStorePort (여긴 integrations 계층) +│ │ ├── faiss_.py # faiss extra +│ │ ├── pgvector_.py # pgvector extra +│ │ └── pinecone_.py # pinecone extra +│ │ +│ ├── metadata/ +│ │ ├── __init__.py +│ │ ├── base.py # SchemaProviderPort +│ │ ├── sqlalchemy_.py # sqlalchemy extra +│ │ └── datahub_.py # datahub extra +│ │ +│ └── langgraph/ # (선택) "브릿지"만. 메인 설계는 아님. +│ ├── __init__.py +│ └── bridge.py # Flow를 langgraph node로 감싸거나, component를 노드로 변환하는 정도 +│ +├── presets/ # ✅ 초보자 UX: “원클릭 생성기” +│ ├── __init__.py +│ ├── factory.py # AutoText2SQL(...) -> BaselineFlow/AgenticFlow + 기본 컴포넌트 조립 +│ └── defaults.py # 기본 조합 정책(vstore 없이도 동작하는 기본값 포함) +│ +├── cli/ +│ ├── __init__.py +│ └── main.py # lang2sql query/run 등 +│ +├── app/ +│ ├── __init__.py +│ └── streamlit/ +│ ├── __init__.py +│ └── main.py +│ +├── tests/ +│ ├── test_baseline.py +│ ├── test_agentic.py +│ ├── test_adapters.py +│ └── test_registry.py +│ +└── docs/ + ├── philosophy.md + ├── quickstart.md + ├── customizing.md + ├── writing-your-own-flow.md + └── plugins.md \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3147fd9..809924c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ path = "version.py" include = [ "version.py", "prompt/*.md", + "src/lang2sql/**", ] [tool.hatch.build.targets.wheel] @@ -72,6 +73,7 @@ packages = [ "infra", "prompt", "utils", + "src/lang2sql" ] [tool.uv] diff --git a/src/lang2sql/__init__.py b/src/lang2sql/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lang2sql/core/__init__.py b/src/lang2sql/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lang2sql/core/base.py b/src/lang2sql/core/base.py new file mode 100644 index 0000000..34e6de0 --- /dev/null +++ b/src/lang2sql/core/base.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Optional + +from .exceptions import ContractError +from .context import RunContext +from .exceptions import ComponentError, Lang2SQLError +from .hooks import Event, NullHook, TraceHook, ms, now, summarize + + +class BaseComponent(ABC): + """ + Base class for all components. + + Design goals: + - Components are plain callables (define-by-run friendly). + - No enforced global state schema. + - Hooks provide observability without requiring a graph engine. + """ + + def __init__( + self, name: Optional[str] = None, hook: Optional[TraceHook] = None + ) -> None: + self.name: str = name or self.__class__.__name__ + self.hook: TraceHook = hook or NullHook() + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + t0 = now() + self.hook.on_event( + Event( + name="component.run", + component=self.name, + phase="start", + ts=t0, + input_summary=f"args={summarize(args)} kwargs={summarize(kwargs)}", + ) + ) + + try: + out = self.run(*args, **kwargs) + + if ( + args + and isinstance(args[0], RunContext) + and not isinstance(out, RunContext) + ): + got = "None" if out is None else type(out).__name__ + raise ContractError( + f"{self.name} must return RunContext (got {got}). Did you forget `return run`?" + ) + + t1 = now() + self.hook.on_event( + Event( + name="component.run", + component=self.name, + phase="end", + ts=t1, + duration_ms=ms(t0, t1), + output_summary=summarize(out), + ) + ) + return out + + except Lang2SQLError as e: + # Preserve domain-level errors (IntegrationMissingError, ValidationError, etc.). + t1 = now() + self.hook.on_event( + Event( + name="component.run", + component=self.name, + phase="error", + ts=t1, + duration_ms=ms(t0, t1), + error=f"{type(e).__name__}: {e}", + ) + ) + raise + + except Exception as e: + # Wrap non-domain errors into ComponentError. + t1 = now() + self.hook.on_event( + Event( + name="component.run", + component=self.name, + phase="error", + ts=t1, + duration_ms=ms(t0, t1), + error=f"{type(e).__name__}: {e}", + ) + ) + raise ComponentError( + self.name, + f"component failed ({type(e).__name__}: {e})", + cause=e, + ) from e + + @abstractmethod + def run(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + +class BaseFlow(ABC): + """ + Base class for flows. + + Define-by-run: + - Users write control-flow in pure Python (if/for/while). + - We provide parts + presets, not a graph engine. + """ + + def __init__( + self, name: Optional[str] = None, hook: Optional[TraceHook] = None + ) -> None: + self.name: str = name or self.__class__.__name__ + self.hook: TraceHook = hook or NullHook() + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + t0 = now() + self.hook.on_event( + Event(name="flow.run", component=self.name, phase="start", ts=t0) + ) + + try: + out = self.run(*args, **kwargs) + t1 = now() + self.hook.on_event( + Event( + name="flow.run", + component=self.name, + phase="end", + ts=t1, + duration_ms=ms(t0, t1), + ) + ) + return out + + except Lang2SQLError as e: + t1 = now() + self.hook.on_event( + Event( + name="flow.run", + component=self.name, + phase="error", + ts=t1, + duration_ms=ms(t0, t1), + error=f"{type(e).__name__}: {e}", + ) + ) + raise + + except Exception as e: + t1 = now() + self.hook.on_event( + Event( + name="flow.run", + component=self.name, + phase="error", + ts=t1, + duration_ms=ms(t0, t1), + error=f"{type(e).__name__}: {e}", + ) + ) + raise + + @abstractmethod + def run(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + def run_query(self, query: str) -> RunContext: + """ + Convenience entrypoint. + + Creates a RunContext(query=...) and runs the flow. + Intended for demos / quickstart. + + Args: + query: Natural language question. + + Returns: + RunContext after running this flow. + """ + out = self.run(RunContext(query=query)) + if not isinstance(out, RunContext): + got = "None" if out is None else type(out).__name__ + raise TypeError( + f"{self.name}.run(run: RunContext) must return RunContext, got {got}" + ) + return out diff --git a/src/lang2sql/core/context.py b/src/lang2sql/core/context.py new file mode 100644 index 0000000..7d3a0f5 --- /dev/null +++ b/src/lang2sql/core/context.py @@ -0,0 +1,182 @@ +# lang2sql/core/context.py +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional +from collections.abc import MutableMapping + + +@dataclass(init=False) +class RunContext: + """ + A minimal state carrier for define-by-run pipelines. + + Internal storage is generic: + - inputs: user inputs (e.g., query) + - artifacts: intermediate artifacts (e.g., schema candidates, prompt context) + - outputs: final outputs (e.g., sql, validation) + - error: structured error information (optional) + - metadata: logs/traces/history (optional) + + Public UX can be domain-friendly via alias properties: + - .query, .schema, .sql, .validation, etc. + """ + + inputs: dict[str, Any] = field(default_factory=dict) + artifacts: dict[str, Any] = field(default_factory=dict) + outputs: dict[str, Any] = field(default_factory=dict) + error: Optional[dict[str, Any]] = None + metadata: dict[str, Any] = field(default_factory=dict) + + def __init__(self, query: Optional[str] = None, **kwargs: Any) -> None: + # Keep storage generic and always initialized. + self.inputs = {} + self.artifacts = {} + self.outputs = {} + self.error = None + self.metadata = {} + + if query is not None: + self.inputs["query"] = query + + # Allow lightweight initialization: RunContext(foo=..., bar=...) + # Store unknown fields in metadata to avoid schema lock-in. + if kwargs: + self.metadata.update(kwargs) + + # ------------------------- + # Domain-friendly aliases + # ------------------------- + + @property + def query(self) -> str: + """Return the user question/query.""" + v = self.inputs.get("query", "") + if v is None: + return "" + if isinstance(v, str): + return v + # Be forgiving: keep pipeline running, but avoid crashing on accidental types. + return str(v) + + @query.setter + def query(self, value: str) -> None: + if not isinstance(value, str): + raise TypeError(f"RunContext.query must be str, got {type(value).__name__}") + self.inputs["query"] = value + + @property + def schema(self) -> MutableMapping[str, Any]: + """ + Return the schema artifact mapping. + + Typical keys (convention): + - catalog: full schema catalog (list/provider) + - selected: top-k table candidates + - context: final context text used for prompting + """ + v = self.artifacts.get("schema") + if v is None: + v = {} + self.artifacts["schema"] = v + return v + + if isinstance(v, MutableMapping): + return v + + # If someone wrote a non-mapping value, replace it to keep conventions stable. + v = {} + self.artifacts["schema"] = v + return v + + @property + def sql(self) -> str: + """Return the final SQL string.""" + v = self.outputs.get("sql", "") + if v is None: + return "" + if isinstance(v, str): + return v + return str(v) + + @sql.setter + def sql(self, value: str) -> None: + if not isinstance(value, str): + raise TypeError(f"RunContext.sql must be str, got {type(value).__name__}") + self.outputs["sql"] = value + + # sql changes invalidate validation (validation is derived from sql) + # self.outputs.pop("validation", None) + + @property + def validation(self) -> Any: + """Return validation result object, if present.""" + return self.outputs.get("validation") + + @validation.setter + def validation(self, value: Any) -> None: + self.outputs["validation"] = value + + # Optional convenience aliases (recommended for discoverability) + + @property + def schema_catalog(self) -> Any: + """Alias for schema['catalog'].""" + return self.schema.get("catalog") + + @schema_catalog.setter + def schema_catalog(self, value: Any) -> None: + self.schema["catalog"] = value + + @property + def schema_selected(self) -> Any: + """Alias for schema['selected'].""" + return self.schema.get("selected") + + @schema_selected.setter + def schema_selected(self, value: Any) -> None: + self.schema["selected"] = value + + @property + def schema_context(self) -> str: + """Alias for schema['context'].""" + v = self.schema.get("context", "") + if v is None: + return "" + if isinstance(v, str): + return v + return str(v) + + @schema_context.setter + def schema_context(self, value: str) -> None: + if not isinstance(value, str): + raise TypeError( + f"RunContext.schema_context must be str, got {type(value).__name__}" + ) + self.schema["context"] = value + + # ------------------------- + # Small utilities + # ------------------------- + + def push_meta(self, key: str, value: Any) -> None: + """ + Append a value into metadata[key] list. + + Example: + run.push_meta("sql_drafts", sql) + run.push_meta("events", event) + """ + arr = self.metadata.setdefault(key, []) + if not isinstance(arr, list): + raise TypeError(f"metadata['{key}'] exists but is not a list") + arr.append(value) + + def get_meta_list(self, key: str) -> list[Any]: + """Return metadata[key] as a list (empty list if missing).""" + v = self.metadata.get(key) + if v is None: + return [] + if isinstance(v, list): + return v + return [v] diff --git a/src/lang2sql/core/exceptions.py b/src/lang2sql/core/exceptions.py new file mode 100644 index 0000000..c4e8556 --- /dev/null +++ b/src/lang2sql/core/exceptions.py @@ -0,0 +1,39 @@ +# lang2sql/core/exceptions.py +from __future__ import annotations + + +class Lang2SQLError(Exception): + """Base error for lang2sql.""" + + +class IntegrationMissingError(Lang2SQLError): + def __init__( + self, integration: str, extra: str | None = None, hint: str | None = None + ): + self.integration = integration + self.extra = extra + self.hint = hint + + msg = f"Missing optional integration: {integration}." + if extra: + msg += f" Install with: pip install 'lang2sql[{extra}]'" + if hint: + msg += f" ({hint})" + super().__init__(msg) + + +class ComponentError(Lang2SQLError): + def __init__(self, component: str, message: str, *, cause: Exception | None = None): + self.component = component + self.cause = cause + super().__init__(f"[{component}] {message}") + + +class ValidationError(Lang2SQLError): + pass + + +class ContractError(Lang2SQLError): + """Raised when a component violates a required call/return contract.""" + + pass diff --git a/src/lang2sql/core/hooks.py b/src/lang2sql/core/hooks.py new file mode 100644 index 0000000..8f2729b --- /dev/null +++ b/src/lang2sql/core/hooks.py @@ -0,0 +1,64 @@ +# lang2sql/core/hooks.py +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol, Optional, Literal +import time + + +@dataclass +class Event: + name: str # e.g., "component.run" + component: str # e.g., "KeywordTableRetriever" + phase: Literal["start", "end", "error"] + ts: float + duration_ms: Optional[float] = None + + # human-friendly summaries (debug only) + input_summary: Optional[str] = None + output_summary: Optional[str] = None + error: Optional[str] = None + + # structured payload (for UI / filtering / tests) + data: dict[str, Any] = field(default_factory=dict) + + +class TraceHook(Protocol): + def on_event(self, event: Event) -> None: ... + + +class NullHook: + def on_event(self, event: Event) -> None: + return + + +class MemoryHook: + def __init__(self) -> None: + self.events: list[Event] = [] + + def on_event(self, event: Event) -> None: + self.events.append(event) + + def clear(self) -> None: + self.events.clear() + + def snapshot(self) -> list[Event]: + return list(self.events) + + +def now() -> float: + return time.time() + + +def ms(start: float, end: float) -> float: + return (end - start) * 1000.0 + + +def summarize(x: Any, max_len: int = 240) -> str: + try: + s = repr(x) + except Exception: + s = f"" + if len(s) > max_len: + s = s[: max_len - 3] + "..." + return s diff --git a/src/lang2sql/flows/baseline.py b/src/lang2sql/flows/baseline.py new file mode 100644 index 0000000..b088e7d --- /dev/null +++ b/src/lang2sql/flows/baseline.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from typing import Iterable, Protocol, Sequence + +from ..core.base import BaseFlow +from ..core.context import RunContext +from ..core.exceptions import ContractError + + +class RunComponent(Protocol): + """ + Protocol for a pipeline component. + + A component is a callable that takes a RunContext and must return a RunContext. + This enforces a strict "RunContext in -> RunContext out" contract. + + Args: + run: The current RunContext. + + Returns: + A RunContext instance (usually the same object, mutated in-place). + """ + + def __call__(self, run: RunContext) -> RunContext: ... + + +class SequentialFlow(BaseFlow): + """ + A minimal sequential pipeline runner (define-by-run style). + + This flow runs `steps` in order. Each step must follow the contract: + RunContext -> RunContext + + Notes: + - Steps may mutate `run` in-place and still must return `run`. + - Returning None or a non-RunContext value is treated as a contract bug and fails fast. + + Args: + steps: Ordered sequence of pipeline components. + name: Optional name override for tracing/logging. + hook: Optional TraceHook. If not provided, a NullHook is used by BaseFlow. + + Returns: + The final RunContext after running all steps. + """ + + def __init__( + self, + *, + steps: Sequence[RunComponent], + name: str | None = None, + hook=None, + ) -> None: + """ + Initialize the flow with an ordered list of steps. + + Args: + steps: Ordered sequence of pipeline components. Must be non-empty. + name: Optional name override. + hook: Optional TraceHook used by BaseFlow for flow-level events. + + Raises: + ValueError: If `steps` is empty. + """ + super().__init__(name=name or "SequentialFlow", hook=hook) + if not steps: + raise ValueError("SequentialFlow requires at least one step.") + self.steps: list[RunComponent] = list(steps) + + @staticmethod + def _apply(step: RunComponent, run: RunContext) -> RunContext: + """ + Apply a single step with strict contract validation. + + Args: + step: A pipeline component (callable). + run: The current RunContext. + + Returns: + The RunContext returned by the step. + + Raises: + ContractError: If the step returns None or a non-RunContext value. + """ + out = step(run) + + if isinstance(out, RunContext): + return out + + got = "None" if out is None else type(out).__name__ + raise ContractError( + f"Component must return RunContext (got {got}). Did you forget `return run`?" + ) + + def _run_steps( + self, run: RunContext, steps: Iterable[RunComponent] | None = None + ) -> RunContext: + """ + Run an iterable of steps sequentially. + + Args: + run: The initial RunContext. + steps: Optional override iterable of steps. If None, uses `self.steps`. + + Returns: + The final RunContext after applying all steps. + """ + it = self.steps if steps is None else steps + for step in it: + run = self._apply(step, run) + return run + + def run(self, run: RunContext) -> RunContext: + """ + Execute the flow on the given RunContext. + + Args: + run: The initial RunContext. + + Returns: + The final RunContext after running all configured steps. + """ + return self._run_steps(run) + + def run_query(self, query: str) -> RunContext: + """ + Beginner-friendly sugar API. + + Args: + query: Natural language question / user query. + + Returns: + A RunContext initialized with `query` and processed by the flow. + """ + return super().run_query(query) + + +# Backward-compatible alias (optional). +# Keeping this alias means existing imports `from lang2sql.flows... import BaselineFlow` +# continue to work without changes. +BaselineFlow = SequentialFlow diff --git a/tests/test_core_base.py b/tests/test_core_base.py new file mode 100644 index 0000000..cba775a --- /dev/null +++ b/tests/test_core_base.py @@ -0,0 +1,224 @@ +import pytest + +from lang2sql.core.base import BaseComponent, BaseFlow +from lang2sql.core.hooks import MemoryHook +from lang2sql.core.exceptions import ( + ComponentError, + ValidationError, + IntegrationMissingError, + ContractError, +) + + +# ------------------------- +# Fixtures: tiny components/flows +# ------------------------- + + +class AddOne(BaseComponent): + def run(self, x: int) -> int: + return x + 1 + + +class BoomValueError(BaseComponent): + def run(self, x: int) -> int: + raise ValueError("boom") + + +class BoomDomainError(BaseComponent): + def run(self, x: int) -> int: + raise ValidationError("bad sql") + + +class BoomIntegrationMissing(BaseComponent): + def run(self, x: int) -> int: + raise IntegrationMissingError("faiss", extra="faiss") + + +class FlowOk(BaseFlow): + def run(self, x: int) -> int: + return x * 2 + + +class FlowBoomDomain(BaseFlow): + def run(self, x: int) -> int: + raise ValidationError("flow bad") + + +class FlowBoomUnknown(BaseFlow): + def run(self, x: int) -> int: + raise RuntimeError("flow boom") + + +# ------------------------- +# BaseComponent tests +# ------------------------- + + +def test_base_component_emits_start_end_events(): + hook = MemoryHook() + c = AddOne(hook=hook) + + out = c(1) + assert out == 2 + + assert len(hook.events) == 2 + assert hook.events[0].name == "component.run" + assert hook.events[0].phase == "start" + assert hook.events[1].name == "component.run" + assert hook.events[1].phase == "end" + assert hook.events[1].duration_ms is not None + assert hook.events[1].duration_ms >= 0.0 + + +def test_base_component_wraps_non_domain_exception_as_component_error(): + hook = MemoryHook() + c = BoomValueError(hook=hook) + + with pytest.raises(ComponentError) as ei: + c(1) + + # error chain preserved + assert isinstance(ei.value.cause, ValueError) + assert "ValueError" in str(ei.value) or "boom" in str(ei.value) + + # events: start + error + assert len(hook.events) == 2 + assert hook.events[0].phase == "start" + assert hook.events[1].phase == "error" + assert "ValueError" in (hook.events[1].error or "") + assert "boom" in (hook.events[1].error or "") + + +def test_base_component_preserves_domain_error_validationerror(): + hook = MemoryHook() + c = BoomDomainError(hook=hook) + + with pytest.raises(ValidationError) as ei: + c(1) + + assert "bad sql" in str(ei.value) + + # events: start + error + assert len(hook.events) == 2 + assert hook.events[0].phase == "start" + assert hook.events[1].phase == "error" + assert "ValidationError" in (hook.events[1].error or "") + assert "bad sql" in (hook.events[1].error or "") + + +def test_base_component_preserves_domain_error_integration_missing(): + hook = MemoryHook() + c = BoomIntegrationMissing(hook=hook) + + with pytest.raises(IntegrationMissingError) as ei: + c(1) + + msg = str(ei.value) + assert "Missing optional integration: faiss" in msg + assert "lang2sql[faiss]" in msg + + # events: start + error + assert len(hook.events) == 2 + assert hook.events[0].phase == "start" + assert hook.events[1].phase == "error" + assert "IntegrationMissingError" in (hook.events[1].error or "") + assert "faiss" in (hook.events[1].error or "") + + +def test_runcontext_contract_none_return_raises_contract_error(): + from lang2sql.core.context import RunContext + + class BadNone(BaseComponent): + def run(self, run: RunContext): + run.metadata["x"] = 1 + return None # forgot return run + + hook = MemoryHook() + c = BadNone(hook=hook) + + with pytest.raises(ContractError) as ei: + c(RunContext(query="q")) + + assert "must return RunContext" in str(ei.value) + assert "None" in str(ei.value) + + assert len(hook.events) == 2 + assert hook.events[0].phase == "start" + assert hook.events[1].phase == "error" + assert "ContractError" in (hook.events[1].error or "") + + +def test_runcontext_contract_wrong_type_return_raises_contract_error(): + from lang2sql.core.context import RunContext + + class BadType(BaseComponent): + def run(self, run: RunContext): + return 123 # wrong + + hook = MemoryHook() + c = BadType(hook=hook) + + with pytest.raises(ContractError) as ei: + c(RunContext(query="q")) + + assert "must return RunContext" in str(ei.value) + assert "int" in str(ei.value) + + assert len(hook.events) == 2 + assert hook.events[0].phase == "start" + assert hook.events[1].phase == "error" + assert "ContractError" in (hook.events[1].error or "") + + +# ------------------------- +# BaseFlow tests +# ------------------------- + + +def test_base_flow_emits_start_end_events(): + hook = MemoryHook() + f = FlowOk(hook=hook) + + out = f(3) + assert out == 6 + + assert len(hook.events) == 2 + assert hook.events[0].name == "flow.run" + assert hook.events[0].phase == "start" + assert hook.events[1].name == "flow.run" + assert hook.events[1].phase == "end" + assert hook.events[1].duration_ms is not None + assert hook.events[1].duration_ms >= 0.0 + + +def test_base_flow_preserves_domain_error(): + hook = MemoryHook() + f = FlowBoomDomain(hook=hook) + + with pytest.raises(ValidationError) as ei: + f(1) + + assert "flow bad" in str(ei.value) + + assert len(hook.events) == 2 + assert hook.events[0].phase == "start" + assert hook.events[1].phase == "error" + assert "ValidationError" in (hook.events[1].error or "") + assert "flow bad" in (hook.events[1].error or "") + + +def test_base_flow_raises_unknown_error_and_emits_error_event(): + hook = MemoryHook() + f = FlowBoomUnknown(hook=hook) + + with pytest.raises(RuntimeError) as ei: + f(1) + + assert "flow boom" in str(ei.value) + + assert len(hook.events) == 2 + assert hook.events[0].phase == "start" + assert hook.events[1].phase == "error" + assert "RuntimeError" in (hook.events[1].error or "") + assert "flow boom" in (hook.events[1].error or "") diff --git a/tests/test_core_context.py b/tests/test_core_context.py new file mode 100644 index 0000000..27bf2bc --- /dev/null +++ b/tests/test_core_context.py @@ -0,0 +1,84 @@ +import pytest + +from lang2sql.core.context import RunContext + + +def test_init_sets_query_and_metadata_kwargs(): + ctx = RunContext(query="hello", user_id="u1", trace_id="t1") + assert ctx.inputs["query"] == "hello" + assert ctx.metadata["user_id"] == "u1" + assert ctx.metadata["trace_id"] == "t1" + + +def test_query_property_default_empty_string(): + ctx = RunContext() + assert ctx.query == "" # inputs에 query 없으면 "" + + +def test_query_setter_updates_inputs(): + ctx = RunContext() + ctx.query = "select something" + assert ctx.inputs["query"] == "select something" + assert ctx.query == "select something" + + +def test_schema_property_initializes_dict_and_persists_reference(): + ctx = RunContext() + s = ctx.schema + assert isinstance(s, dict) + assert ctx.artifacts["schema"] is s # 같은 객체를 보장(참조 유지) + + s["selected"] = ["users", "orders"] + assert ctx.artifacts["schema"]["selected"] == ["users", "orders"] + + +def test_schema_property_overwrites_non_dict_value(): + ctx = RunContext() + ctx.artifacts["schema"] = ["not", "a", "dict"] + s = ctx.schema + assert isinstance(s, dict) + assert ctx.artifacts["schema"] == {} # dict가 아니면 {}로 교체됨 + + +def test_sql_property_get_set(): + ctx = RunContext() + assert ctx.sql == "" + ctx.sql = "SELECT 1;" + assert ctx.outputs["sql"] == "SELECT 1;" + assert ctx.sql == "SELECT 1;" + + +def test_validation_property_get_set(): + ctx = RunContext() + assert ctx.validation is None + ctx.validation = {"ok": True, "warnings": []} + assert ctx.outputs["validation"] == {"ok": True, "warnings": []} + assert ctx.validation == {"ok": True, "warnings": []} + + +def test_push_meta_appends_to_list(): + ctx = RunContext() + ctx.push_meta("events", {"name": "start"}) + ctx.push_meta("events", {"name": "end"}) + assert ctx.metadata["events"] == [{"name": "start"}, {"name": "end"}] + + +def test_push_meta_raises_when_existing_not_list(): + ctx = RunContext() + ctx.metadata["events"] = {"name": "oops"} # list가 아닌 값 + with pytest.raises(TypeError): + ctx.push_meta("events", {"name": "start"}) + + +def test_get_meta_list_returns_empty_when_missing(): + ctx = RunContext() + assert ctx.get_meta_list("missing") == [] + + +def test_get_meta_list_returns_list_or_wraps_scalar(): + ctx = RunContext() + ctx.metadata["items"] = [1, 2, 3] + assert ctx.get_meta_list("items") == [1, 2, 3] + + ctx.metadata["single"] = 42 + assert ctx.get_meta_list("single") == [42] diff --git a/tests/test_core_exceptions.py b/tests/test_core_exceptions.py new file mode 100644 index 0000000..b67716e --- /dev/null +++ b/tests/test_core_exceptions.py @@ -0,0 +1,40 @@ +import pytest + +from lang2sql.core.exceptions import ( + IntegrationMissingError, + ComponentError, + Lang2SQLError, +) + + +def test_integration_missing_error_message_includes_extra_hint(): + err = IntegrationMissingError("faiss", extra="faiss") + msg = str(err) + assert "Missing optional integration: faiss" in msg + assert "pip install 'lang2sql[faiss]'" in msg + + +def test_integration_missing_error_message_includes_hint_when_provided(): + err = IntegrationMissingError("openai", extra="openai", hint="Needed for LLM calls") + msg = str(err) + assert "Missing optional integration: openai" in msg + assert "pip install 'lang2sql[openai]'" in msg + assert "Needed for LLM calls" in msg + + +def test_component_error_wraps_component_name_and_message(): + err = ComponentError("KeywordTableRetriever", "component failed") + msg = str(err) + assert msg.startswith("[KeywordTableRetriever]") + assert "component failed" in msg + + +def test_exceptions_are_subclasses_of_base_error(): + assert issubclass(IntegrationMissingError, Lang2SQLError) + assert issubclass(ComponentError, Lang2SQLError) + + +def test_component_error_can_chain_cause(): + root = ValueError("boom") + err = ComponentError("X", "failed", cause=root) + assert err.cause is root diff --git a/tests/test_core_hooks.py b/tests/test_core_hooks.py new file mode 100644 index 0000000..60cf341 --- /dev/null +++ b/tests/test_core_hooks.py @@ -0,0 +1,74 @@ +from lang2sql.core.hooks import ( + Event, + MemoryHook, + NullHook, + summarize, + now, + ms, +) + + +def test_memory_hook_collects_events(): + hook = MemoryHook() + e1 = Event(name="x", component="c", phase="start", ts=123.0) + e2 = Event(name="x", component="c", phase="end", ts=124.0, duration_ms=1.0) + hook.on_event(e1) + hook.on_event(e2) + + assert len(hook.events) == 2 + assert hook.events[0].phase == "start" + assert hook.events[1].phase == "end" + + +def test_null_hook_does_not_crash(): + hook = NullHook() + hook.on_event( + Event(name="x", component="c", phase="start", ts=0.0) + ) # should not raise + + +def test_summarize_truncates_long_repr(): + long = "a" * 1000 + s = summarize(long, max_len=50) + assert len(s) <= 50 + assert s.endswith("...") + + +def test_now_and_ms_work(): + t0 = now() + t1 = now() + assert t1 >= t0 + d = ms(t0, t1) + assert d >= 0.0 + + +def test_memory_hook_clear_resets_events(): + hook = MemoryHook() + hook.on_event(Event(name="x", component="c", phase="start", ts=0.0)) + assert len(hook.events) == 1 + hook.clear() + assert len(hook.events) == 0 + + +def test_memory_hook_snapshot_is_copy(): + hook = MemoryHook() + hook.on_event(Event(name="x", component="c", phase="start", ts=0.0)) + + snap = hook.snapshot() + assert snap is not hook.events + assert len(snap) == 1 + + # mutate original after snapshot + hook.on_event(Event(name="x", component="c", phase="end", ts=1.0)) + assert len(hook.events) == 2 + assert len(snap) == 1 # snapshot should not change + + +class BadRepr: + def __repr__(self): + raise RuntimeError("boom") + + +def test_summarize_handles_bad_repr(): + s = summarize(BadRepr(), max_len=50) + assert "unreprable" in s diff --git a/tests/test_flows_baseline.py b/tests/test_flows_baseline.py new file mode 100644 index 0000000..72804fc --- /dev/null +++ b/tests/test_flows_baseline.py @@ -0,0 +1,256 @@ +import pytest + +from lang2sql.core.context import RunContext +from lang2sql.core.base import BaseFlow +from lang2sql.core.exceptions import ContractError +from lang2sql.flows.baseline import BaselineFlow + + +def test_requires_at_least_one_step(): + with pytest.raises(ValueError): + BaselineFlow(steps=[]) + + +def test_run_query_sets_inputs_query(): + def step(run: RunContext) -> RunContext: + run.sql = "SELECT 1;" + return run + + flow = BaselineFlow(steps=[step]) + out = flow.run_query("지난달 매출") + + assert out.inputs["query"] == "지난달 매출" + assert out.query == "지난달 매출" + assert out.sql == "SELECT 1;" + + +def test_ctx_mutate_style_step_mutates_and_returns_same_context(): + def step(run: RunContext) -> RunContext: + run.metadata["x"] = 1 + return run + + flow = BaselineFlow(steps=[step]) + run = RunContext(query="q") + out = flow.run(run) + + assert out is run + assert out.metadata["x"] == 1 + + +def test_functional_style_step_can_return_new_context(): + def step(run: RunContext) -> RunContext: + new = RunContext(query=run.query) + new.sql = "SELECT 2;" + return new + + flow = BaselineFlow(steps=[step]) + run = RunContext(query="q") + out = flow.run(run) + + assert out is not run + assert out.query == "q" + assert out.sql == "SELECT 2;" + + +def test_invalid_step_return_type_raises_contract_error(): + def bad_step(run: RunContext): + return 123 # invalid + + flow = BaselineFlow(steps=[bad_step]) + with pytest.raises(ContractError): + flow.run(RunContext(query="q")) + + +def test_step_order_is_preserved(): + def s1(run: RunContext) -> RunContext: + run.push_meta("order", "s1") + return run + + def s2(run: RunContext) -> RunContext: + run.push_meta("order", "s2") + return run + + def s3(run: RunContext) -> RunContext: + run.push_meta("order", "s3") + return run + + flow = BaselineFlow(steps=[s1, s2, s3]) + out = flow.run(RunContext(query="q")) + + assert out.get_meta_list("order") == ["s1", "s2", "s3"] + + +def test_user_can_override_pipeline_by_composing_flows_without_private_api(): + def default_step(run: RunContext) -> RunContext: + run.push_meta("order", "default") + return run + + def override_step(run: RunContext) -> RunContext: + run.push_meta("order", "override") + return run + + flow_default = BaselineFlow(steps=[default_step]) + flow_override = BaselineFlow(steps=[override_step]) + + out_default = flow_default(RunContext(query="q")) + assert out_default.get_meta_list("order") == ["default"] + + class CustomFlow(BaseFlow): + def run(self, run: RunContext) -> RunContext: + # Explicitly choose override pipeline + return flow_override(run) + + out = CustomFlow().run(RunContext(query="q")) + assert out.get_meta_list("order") == ["override"] + + +# ------------------------- +# 1) Advanced: retry patterns (NO private API) +# ------------------------- + + +def test_custom_flow_fallback_then_revalidate_makes_validation_ok(): + def gen_bad(run: RunContext) -> RunContext: + run.sql = "DROP TABLE users;" + return run + + class _V: + def __init__(self, ok: bool): + self.ok = ok + + def validate(run: RunContext) -> RunContext: + ok = "drop " not in run.sql.lower() + run.validation = _V(ok) + return run + + pipeline = BaselineFlow(steps=[gen_bad, validate]) # gen -> validate + + class FixThenRevalidateFlow(BaseFlow): + def run(self, run: RunContext) -> RunContext: + pipeline(run) + if run.validation.ok: + return run + + run.sql = "SELECT 1;" + validate(run) # explicit re-validate + return run + + out = FixThenRevalidateFlow().run(RunContext(query="q")) + assert out.sql == "SELECT 1;" + assert out.validation.ok is True + + +def test_custom_flow_retry_regenerates_sql_until_valid(): + def gen_with_attempt(run: RunContext) -> RunContext: + attempt = int(run.metadata.get("attempt", 0)) + run.metadata["attempt"] = attempt + 1 + + if attempt == 0: + run.sql = "DROP TABLE users;" + else: + run.sql = "SELECT 1;" + return run + + class _V: + def __init__(self, ok: bool): + self.ok = ok + + def validate(run: RunContext) -> RunContext: + ok = "drop " not in run.sql.lower() + run.validation = _V(ok) + return run + + class RegenerateRetryFlow(BaseFlow): + def run(self, run: RunContext) -> RunContext: + for _ in range(3): + gen_with_attempt(run) + validate(run) + if run.validation.ok: + return run + return run + + out = RegenerateRetryFlow().run(RunContext(query="q")) + assert out.sql == "SELECT 1;" + assert out.validation.ok is True + assert out.metadata["attempt"] >= 2 + + +# ------------------------- +# 2) Composition: flow as a step (subflow) +# ------------------------- + + +def test_subflow_can_be_used_as_a_step_and_mutates_same_context(): + def a1(run: RunContext) -> RunContext: + run.push_meta("trace", "a1") + return run + + def a2(run: RunContext) -> RunContext: + run.sql = "SELECT 42;" + run.push_meta("trace", "a2") + return run + + flow_a = BaselineFlow(steps=[a1, a2]) + + def b1(run: RunContext) -> RunContext: + run.push_meta("trace", "b1") + return run + + def b2(run: RunContext) -> RunContext: + run.push_meta("trace", "b2") + return run + + flow_b = BaselineFlow(steps=[b1, flow_a, b2]) + + run = RunContext(query="q") + out = flow_b.run(run) + + assert out is run + assert out.get_meta_list("trace") == ["b1", "a1", "a2", "b2"] + assert out.sql == "SELECT 42;" + + +def test_subflow_can_be_conditionally_invoked_in_custom_flow(): + def a1(run: RunContext) -> RunContext: + run.push_meta("trace", "a1") + return run + + def a2(run: RunContext) -> RunContext: + run.push_meta("trace", "a2") + return run + + flow_a = BaselineFlow(steps=[a1, a2]) + + def b1(run: RunContext) -> RunContext: + run.push_meta("trace", "b1") + return run + + def b2(run: RunContext) -> RunContext: + run.push_meta("trace", "b2") + return run + + class ConditionalFlow(BaseFlow): + def run(self, run: RunContext) -> RunContext: + b1(run) + if "use_a" in run.query: + flow_a(run) + b2(run) + return run + + out1 = ConditionalFlow().run(RunContext(query="nope")) + assert out1.get_meta_list("trace") == ["b1", "b2"] + + out2 = ConditionalFlow().run(RunContext(query="please use_a")) + assert out2.get_meta_list("trace") == ["b1", "a1", "a2", "b2"] + + +def test_none_return_raises_contract_error(): + def bad_none(run: RunContext): + run.sql = "SELECT 1;" + return None # forgot return run + + flow = BaselineFlow(steps=[bad_none]) + with pytest.raises(ContractError) as ei: + flow.run(RunContext(query="q")) + + assert "Did you forget" in str(ei.value) or "return run" in str(ei.value) diff --git a/version.py b/version.py index 7c5b6f3..87ec59a 100644 --- a/version.py +++ b/version.py @@ -18,4 +18,4 @@ - PATCH는 1로 증가합니다. """ -__version__ = "0.2.2" +__version__ = "0.3.0"