From 935b4b16d52829e8144d8b6df914835ac3987a3c Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 14 Feb 2026 18:07:47 -0600 Subject: [PATCH 1/6] Simpler name --- pyhealth/interpret/methods/lime.py | 4 +++- pyhealth/interpret/methods/shap.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 8481c50bf..47df3859a 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -12,7 +12,7 @@ from .base_interpreter import BaseInterpreter -class LimeExplainer(BaseInterpreter): +class Lime(BaseInterpreter): """LIME (Local Interpretable Model-agnostic Explanations) attribution method for PyHealth models. This class implements the LIME method for computing feature attributions in @@ -823,3 +823,5 @@ def _map_to_input_shapes( mapped[key] = reshaped return mapped + +type LimeExplainer = Lime # Alias for backward compatibility \ No newline at end of file diff --git a/pyhealth/interpret/methods/shap.py b/pyhealth/interpret/methods/shap.py index edcd02fc2..f67864a63 100644 --- a/pyhealth/interpret/methods/shap.py +++ b/pyhealth/interpret/methods/shap.py @@ -10,7 +10,7 @@ from .base_interpreter import BaseInterpreter -class ShapExplainer(BaseInterpreter): +class Shap(BaseInterpreter): """SHAP (SHapley Additive exPlanations) attribution method for PyHealth models. 
This class implements the SHAP method for computing feature attributions in @@ -771,4 +771,6 @@ def _map_to_input_shapes( mapped[key] = reshaped - return mapped \ No newline at end of file + return mapped + +type ShapExplainer = Shap # Alias for backward compatibility \ No newline at end of file From 8282d2dc73f81b53d1b8c0272477d2e2a0b1e430 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 14 Feb 2026 18:14:13 -0600 Subject: [PATCH 2/6] Add API for processors --- pyhealth/processors/base_processor.py | 10 ++++++++++ pyhealth/processors/deep_nested_sequence_processor.py | 8 ++++++++ pyhealth/processors/nested_sequence_processor.py | 8 ++++++++ pyhealth/processors/sequence_processor.py | 8 ++++++++ pyhealth/processors/stagenet_processor.py | 8 ++++++++ 5 files changed, 42 insertions(+) diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py index 48cbe26ae..5611cccfa 100644 --- a/pyhealth/processors/base_processor.py +++ b/pyhealth/processors/base_processor.py @@ -167,4 +167,14 @@ def retain(self, tokens: set[str]): @abstractmethod def add(self, tokens: set[str]): """Add specified vocabularies to the processor.""" + pass + + @abstractmethod + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + pass + + @abstractmethod + def num_tokens(self) -> int: + """Return the size of the processor's vocabulary.""" pass \ No newline at end of file diff --git a/pyhealth/processors/deep_nested_sequence_processor.py b/pyhealth/processors/deep_nested_sequence_processor.py index 2d67633eb..948c15c8d 100644 --- a/pyhealth/processors/deep_nested_sequence_processor.py +++ b/pyhealth/processors/deep_nested_sequence_processor.py @@ -107,6 +107,10 @@ def add(self, tokens: set[str]): self.code_vocab[token] = i i += 1 + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + return set(self.code_vocab.keys()) + def process(self, value: List[List[List[Any]]]) -> 
torch.Tensor: """Process deep nested sequence into padded 3D tensor. @@ -169,6 +173,10 @@ def process(self, value: List[List[List[Any]]]) -> torch.Tensor: return torch.tensor(encoded_groups, dtype=torch.long) + def num_tokens(self) -> int: + """Return the size of the processor's vocabulary.""" + return len(self.code_vocab) + def size(self) -> int: """Return max inner length (embedding dimension) for unified API.""" return self._max_inner_len diff --git a/pyhealth/processors/nested_sequence_processor.py b/pyhealth/processors/nested_sequence_processor.py index 461575621..0580d3536 100644 --- a/pyhealth/processors/nested_sequence_processor.py +++ b/pyhealth/processors/nested_sequence_processor.py @@ -103,6 +103,10 @@ def add(self, tokens: set[str]): self.code_vocab[token] = i i += 1 + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + return set(self.code_vocab.keys()) + def process(self, value: List[List[Any]]) -> torch.Tensor: """Process nested sequence into padded 2D tensor. 
@@ -146,6 +150,10 @@ def process(self, value: List[List[Any]]) -> torch.Tensor: return torch.tensor(encoded_sequences, dtype=torch.long) + def num_tokens(self) -> int: + """Return the size of the processor's vocabulary.""" + return len(self.code_vocab) + def size(self) -> int: """Return max inner length (embedding dimension) for unified API.""" return self._max_inner_len diff --git a/pyhealth/processors/sequence_processor.py b/pyhealth/processors/sequence_processor.py index d7a7b1ddf..14d7a2000 100644 --- a/pyhealth/processors/sequence_processor.py +++ b/pyhealth/processors/sequence_processor.py @@ -68,6 +68,14 @@ def add(self, tokens: set[str]): self.code_vocab[token] = i i += 1 + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + return set(self.code_vocab.keys()) + + def num_tokens(self) -> int: + """Return the size of the processor's vocabulary.""" + return len(self.code_vocab) + def size(self): return len(self.code_vocab) diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py index cce8819c5..00e88ad04 100644 --- a/pyhealth/processors/stagenet_processor.py +++ b/pyhealth/processors/stagenet_processor.py @@ -139,6 +139,10 @@ def add(self, tokens: set[str]): self.code_vocab[token] = i i += 1 + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + return set(self.code_vocab.keys()) + def process( self, value: Tuple[Optional[List], List] ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: @@ -217,6 +221,10 @@ def _encode_nested_codes(self, nested_codes: List[List[str]]) -> torch.Tensor: return torch.tensor(encoded_sequences, dtype=torch.long) + def num_tokens(self) -> int: + """Return the size of the processor's vocabulary.""" + return len(self.code_vocab) + def size(self) -> int: """Return vocabulary size.""" return len(self.code_vocab) From 259dac9365d2767d08f47e116e7521e6225d05a8 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Mon, 16 
Feb 2026 06:27:58 -0600 Subject: [PATCH 3/6] Allow Shap to be truly black box --- pyhealth/interpret/methods/shap.py | 39 ++++++++++++++++-------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/pyhealth/interpret/methods/shap.py b/pyhealth/interpret/methods/shap.py index f67864a63..46d40a977 100644 --- a/pyhealth/interpret/methods/shap.py +++ b/pyhealth/interpret/methods/shap.py @@ -121,8 +121,6 @@ def __init__( implement forward_from_embedding() method. """ super().__init__(model) - if not isinstance(model, Interpretable): - raise ValueError("Model must implement Interpretable interface") self.model = model self.use_embeddings = use_embeddings self.n_background_samples = n_background_samples @@ -131,13 +129,8 @@ def __init__( self.random_seed = random_seed # Validate model requirements - if use_embeddings: - assert hasattr(model, "forward_from_embedding"), ( - f"Model {type(model).__name__} must implement " - "forward_from_embedding() method to support embedding-level " - "SHAP values. Set use_embeddings=False to use " - "input-level attributions (only for continuous features)." - ) + if use_embeddings and not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface or use_embeddings must be False.") # ------------------------------------------------------------------ # Public API @@ -275,7 +268,7 @@ def attribute( # (raw indices are meaningless for interpolation), while continuous # features stay raw so each raw dimension gets its own SHAP value. # Continuous features will be embedded inside _evaluate_sample(). - if self.use_embeddings: + if self.use_embeddings and isinstance(self.model, Interpretable): embedding_model = self.model.get_embedding_model() assert embedding_model is not None, ( "Model must have an embedding model for embedding-based SHAP." @@ -507,10 +500,12 @@ def _evaluate_sample( Target-class prediction scalar per batch item, shape (batch_size,). 
""" inputs = inputs.copy() - # For continuous (non-token) features, embed through embedding_model - # so forward_from_embedding receives proper embeddings. - # Token features were already embedded before perturbation. - if self.use_embeddings: + + if self.use_embeddings and isinstance(self.model, Interpretable): + # For continuous (non-token) features, embed through + # embedding_model so forward_from_embedding receives proper + # embeddings. Token features were already embedded before + # perturbation. embedding_model = self.model.get_embedding_model() assert embedding_model is not None, ( "Model must have an embedding model for embedding-based SHAP." @@ -534,7 +529,13 @@ def _evaluate_sample( *inputs[k][schema.index("value") + 1 :], ) - logits = self.model.forward_from_embedding(**inputs)["logit"] + if self.use_embeddings and isinstance(self.model, Interpretable): + # Values are already embedded; bypass the model's own embedding. + logits = self.model.forward_from_embedding(**inputs)["logit"] + else: + # Values are raw (token IDs / continuous floats); let the + # model's regular forward pass handle embedding internally. + logits = self.model.forward(**inputs)["logit"] return self._extract_target_prediction(logits, target) @@ -666,9 +667,11 @@ def _generate_background_samples( baselines = {} for k, v in values.items(): - if use_embeddings and self.model.dataset.input_processors[k].is_token(): + if self.model.dataset.input_processors[k].is_token(): # Token features: UNK token (index 1) as baseline. - # Embedding happens later in attribute(). + # When use_embeddings=True, embedding happens later in + # attribute(); when use_embeddings=False, the UNK token + # IDs are used directly as the perturbed replacement. 
baseline = torch.ones_like(v) else: # Continuous features: use small neutral values (near-zero) @@ -773,4 +776,4 @@ def _map_to_input_shapes( return mapped -type ShapExplainer = Shap # Alias for backward compatibility \ No newline at end of file +ShapExplainer = Shap # Alias for backward compatibility \ No newline at end of file From 3c330f7f4ace996896a9af4c39b9a64744c15f15 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Mon, 16 Feb 2026 06:44:27 -0600 Subject: [PATCH 4/6] Enhance Lime model handling for embeddings and perturbations --- pyhealth/interpret/methods/lime.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 47df3859a..5176bfeaf 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -488,8 +488,12 @@ def _evaluate_sample( Returns: Model prediction for the perturbed sample, shape (batch_size, ). """ - # Embed continuous (non-token) perturbed features that are still raw + inputs = inputs.copy() + if self.use_embeddings: + # Embed continuous (non-token) perturbed features that are still raw + # so forward_from_embedding receives proper embeddings. + # Token features were already embedded before perturbation. embedding_model = self.model.get_embedding_model() assert embedding_model is not None, ( "Model must have an embedding model for embedding-based LIME." @@ -506,13 +510,18 @@ def _evaluate_sample( for k, v in perturb.items() } - inputs = inputs.copy() for k in inputs.keys(): # Insert perturbed value tensor back into input tuple schema = self.model.dataset.input_processors[k].schema() inputs[k] = (*inputs[k][:schema.index("value")], perturb[k], *inputs[k][schema.index("value")+1:]) - - logits = self.model.forward_from_embedding(**inputs)["logit"] + + if self.use_embeddings: + # Values are already embedded; bypass the model's own embedding. 
+ logits = self.model.forward_from_embedding(**inputs)["logit"] + else: + # Values are raw (token IDs / continuous floats); let the + # model's regular forward pass handle embedding internally. + logits = self.model.forward(**inputs)["logit"] # Reduce to [batch_size, ] by taking absolute difference from target class logit return (target - logits).abs().mean(dim=tuple(range(1, logits.ndim))) @@ -748,11 +757,14 @@ def _generate_baseline( for k, v in values.items(): processor = self.model.dataset.input_processors[k] - if use_embeddings and processor.is_token(): - # Token features: UNK token index as baseline + if processor.is_token(): + # Token features: UNK token (index 1) as baseline. + # When use_embeddings=True, embedding happens later in + # attribute(); when use_embeddings=False, the UNK token + # IDs are used directly as the perturbed replacement. baseline = torch.ones_like(v) else: - # Continuous features (or non-embedding mode): near-zero baseline + # Continuous features: use small neutral values (near-zero) baseline = torch.zeros_like(v) + 1e-2 baselines[k] = baseline @@ -824,4 +836,4 @@ def _map_to_input_shapes( return mapped -type LimeExplainer = Lime # Alias for backward compatibility \ No newline at end of file +LimeExplainer = Lime # Alias for backward compatibility \ No newline at end of file From 5bbd55f2252c5293d37b50fd7c8ff45bc4713240 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Mon, 16 Feb 2026 06:58:19 -0600 Subject: [PATCH 5/6] Add playground for testing new interpret methods --- .../interpretability/interpret_playground.py | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 examples/interpretability/interpret_playground.py diff --git a/examples/interpretability/interpret_playground.py b/examples/interpretability/interpret_playground.py new file mode 100644 index 000000000..1d25dd560 --- /dev/null +++ b/examples/interpretability/interpret_playground.py @@ -0,0 +1,160 @@ +"""Evaluate all interpretability methods on 
StageNet + MIMIC-IV dataset using comprehensiveness +and sufficiency metrics. + +This example demonstrates: +1. Loading a pre-trained StageNet model with processors and MIMIC-IV dataset +2. Computing attributions with various interpretability methods +3. Evaluating attribution faithfulness with Comprehensiveness & Sufficiency for each method +4. Presenting results in a summary table +""" + +import datetime +import argparse +from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient +from pyhealth.interpret.methods import * +from pyhealth.metrics.interpretability import evaluate_attribution +from pyhealth.models import Transformer +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 +from pyhealth.trainer import Trainer +from pyhealth.datasets.utils import load_processors +from pathlib import Path +import pandas as pd + +# python -u examples/interpretability/interpret_playground.py --device cuda:2 +def main(): + parser = argparse.ArgumentParser( + description="Comma separated list of interpretability methods to evaluate" + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0", + help="Device to use for evaluation (default: cuda:0)", + ) + + """Main execution function.""" + print("=" * 70) + print("Interpretability Metrics Example: Transformer + MIMIC-IV") + print("=" * 70) + + now = datetime.datetime.now() + print(f"Start Time: {now.strftime('%Y-%m-%d %H:%M:%S')}") + + # Set path + CACHE_DIR = Path("/shared/eng/pyhealth_dka/cache/mp_mimic4") + CKPTS_DIR = Path("/shared/eng/pyhealth_dka/ckpts/mp_transformer_mimic4") + OUTPUT_DIR = Path("/shared/eng/pyhealth_dka/output/mp_transformer_mimic4") + CACHE_DIR.mkdir(parents=True, exist_ok=True) + CKPTS_DIR.mkdir(parents=True, exist_ok=True) + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + print(f"\nUsing cache dir: {CACHE_DIR}") + print(f"Using checkpoints dir: {CKPTS_DIR}") + print(f"Using output dir: {OUTPUT_DIR}") + + # Set device + device = parser.parse_args().device + 
print(f"\nUsing device: {device}") + + # Load MIMIC-IV dataset + print("\n Loading MIMIC-IV dataset...") + base_dataset = MIMIC4Dataset( + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", + ehr_tables=[ + "patients", + "admissions", + "diagnoses_icd", + "procedures_icd", + "labevents", + ], + cache_dir=str(CACHE_DIR), + num_workers=16, + ) + + # Apply mortality prediction task + if not (CKPTS_DIR / "input_processors.pkl").exists(): + raise FileNotFoundError(f"Input processors not found in {CKPTS_DIR}. ") + if not (CKPTS_DIR / "output_processors.pkl").exists(): + raise FileNotFoundError(f"Output processors not found in {CKPTS_DIR}. ") + input_processors, output_processors = load_processors(str(CKPTS_DIR)) + print("✓ Loaded input and output processors from checkpoint directory.") + + sample_dataset = base_dataset.set_task( + MortalityPredictionStageNetMIMIC4(), + num_workers=16, + input_processors=input_processors, + output_processors=output_processors, + ) + print(f"✓ Loaded {len(sample_dataset)} samples") + + # Split dataset and get test loader + _, _, test_dataset = split_by_patient(sample_dataset, [0.9, 0.09, 0.01], seed=233) + test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False) + print(f"✓ Test set: {len(test_dataset)} samples") + + # Initialize and load pre-trained model + print("\n Loading pre-trained Transformer model...") + model = Transformer( + dataset=sample_dataset, + embedding_dim=128, + heads=4, + dropout=0.3, + num_layers=3, + ) + + trainer = Trainer(model=model, device=device) + trainer.load_ckpt(str(CKPTS_DIR / "best.ckpt")) + model = model.to(device) + model.eval() + print(f"✓ Loaded checkpoint: {CKPTS_DIR / 'best.ckpt'}") + print(f"✓ Model moved to {device}") + + methods: dict[str, BaseInterpreter] = { + "random": RandomBaseline(model), + "shap (emb)": ShapExplainer(model, use_embeddings=True), + "shap": ShapExplainer(model, use_embeddings=False), + "lime (emb)": LimeExplainer(model, use_embeddings=True), + "lime": 
LimeExplainer(model, use_embeddings=False), + } + print(f"\nEvaluating methods: {list(methods.keys())}") + + res = {} + for name, method in methods.items(): + print(f"\n Initializing {name}...") + print("=" * 70) + + # Option 1: Functional API (simple one-off evaluation) + print("\nEvaluating with Functional API on full dataset...") + print("Using: evaluate_attribution(model, dataloader, method, ...)") + + results_functional = evaluate_attribution( + model, + test_loader, + method, + metrics=["comprehensiveness", "sufficiency"], + percentages=[25, 50, 99], + ) + + print("\n" + "=" * 70) + print("Dataset-Wide Results (Functional API)") + print("=" * 70) + comp = results_functional["comprehensiveness"] + suff = results_functional["sufficiency"] + print(f"\nComprehensiveness: {comp:.4f}") + print(f"Sufficiency: {suff:.4f}") + + res[name] = { + "comp": comp, + "suff": suff, + } + + print("") + print("=" * 70) + print("Summary of Results for All Methods") + print(res) + + end = datetime.datetime.now() + print(f"End Time: {end.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"Total Duration: {end - now}") + +if __name__ == "__main__": + main() From 639602761475c3a22236f1762cd00b219f6248f1 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sun, 22 Feb 2026 06:30:56 -0600 Subject: [PATCH 6/6] rename method name --- pyhealth/processors/base_processor.py | 2 +- pyhealth/processors/deep_nested_sequence_processor.py | 2 +- pyhealth/processors/nested_sequence_processor.py | 2 +- pyhealth/processors/sequence_processor.py | 2 +- pyhealth/processors/stagenet_processor.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py index 5611cccfa..932f81a69 100644 --- a/pyhealth/processors/base_processor.py +++ b/pyhealth/processors/base_processor.py @@ -175,6 +175,6 @@ def tokens(self) -> set[str]: pass @abstractmethod - def num_tokens(self) -> int: + def vocab_size(self) -> int: """Return the size of the 
processor's vocabulary.""" pass \ No newline at end of file diff --git a/pyhealth/processors/deep_nested_sequence_processor.py b/pyhealth/processors/deep_nested_sequence_processor.py index 948c15c8d..24683f54b 100644 --- a/pyhealth/processors/deep_nested_sequence_processor.py +++ b/pyhealth/processors/deep_nested_sequence_processor.py @@ -173,7 +173,7 @@ def process(self, value: List[List[List[Any]]]) -> torch.Tensor: return torch.tensor(encoded_groups, dtype=torch.long) - def num_tokens(self) -> int: + def vocab_size(self) -> int: """Return the size of the processor's vocabulary.""" return len(self.code_vocab) diff --git a/pyhealth/processors/nested_sequence_processor.py b/pyhealth/processors/nested_sequence_processor.py index 0580d3536..c03c800d0 100644 --- a/pyhealth/processors/nested_sequence_processor.py +++ b/pyhealth/processors/nested_sequence_processor.py @@ -150,7 +150,7 @@ def process(self, value: List[List[Any]]) -> torch.Tensor: return torch.tensor(encoded_sequences, dtype=torch.long) - def num_tokens(self) -> int: + def vocab_size(self) -> int: """Return the size of the processor's vocabulary.""" return len(self.code_vocab) diff --git a/pyhealth/processors/sequence_processor.py b/pyhealth/processors/sequence_processor.py index 14d7a2000..7792709bb 100644 --- a/pyhealth/processors/sequence_processor.py +++ b/pyhealth/processors/sequence_processor.py @@ -72,7 +72,7 @@ def tokens(self) -> set[str]: """Return the set of tokens in the processor's vocabulary.""" return set(self.code_vocab.keys()) - def num_tokens(self) -> int: + def vocab_size(self) -> int: """Return the size of the processor's vocabulary.""" return len(self.code_vocab) diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py index 00e88ad04..b3b4473e9 100644 --- a/pyhealth/processors/stagenet_processor.py +++ b/pyhealth/processors/stagenet_processor.py @@ -221,7 +221,7 @@ def _encode_nested_codes(self, nested_codes: List[List[str]]) -> torch.Tensor: 
return torch.tensor(encoded_sequences, dtype=torch.long) - def num_tokens(self) -> int: + def vocab_size(self) -> int: """Return the size of the processor's vocabulary.""" return len(self.code_vocab)