diff --git a/examples/interpretability/interpret_playground.py b/examples/interpretability/interpret_playground.py
new file mode 100644
index 000000000..1d25dd560
--- /dev/null
+++ b/examples/interpretability/interpret_playground.py
@@ -0,0 +1,186 @@
+"""Evaluate interpretability methods on Transformer + MIMIC-IV dataset using comprehensiveness
+and sufficiency metrics.
+
+This example demonstrates:
+1. Loading a pre-trained Transformer model with processors and MIMIC-IV dataset
+2. Computing attributions with various interpretability methods
+3. Evaluating attribution faithfulness with Comprehensiveness & Sufficiency for each method
+4. Presenting results in a summary table
+"""
+
+import datetime
+import argparse
+from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient
+from pyhealth.interpret.methods import *
+from pyhealth.metrics.interpretability import evaluate_attribution
+from pyhealth.models import Transformer
+from pyhealth.tasks import MortalityPredictionStageNetMIMIC4
+from pyhealth.trainer import Trainer
+from pyhealth.datasets.utils import load_processors
+from pathlib import Path
+import pandas as pd
+
+# python -u examples/interpretability/interpret_playground.py --device cuda:2
+def main():
+    """Evaluate each configured interpretability method and print a summary table.
+
+    Parses ``--device``, loads the MIMIC-IV dataset with the saved processors
+    and the Transformer checkpoint, runs ``evaluate_attribution`` for every
+    method, then prints comprehensiveness and sufficiency per method and
+    writes the same table to the output directory.
+
+    Raises:
+        FileNotFoundError: If the saved processors or the model checkpoint
+            are missing from the checkpoint directory.
+    """
+    parser = argparse.ArgumentParser(
+        description="Evaluate interpretability methods on Transformer + MIMIC-IV"
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0",
+        help="Device to use for evaluation (default: cuda:0)",
+    )
+    # Parse first so --help or a bad flag exits before any directory is created.
+    args = parser.parse_args()
+
+    print("=" * 70)
+    print("Interpretability Metrics Example: Transformer + MIMIC-IV")
+    print("=" * 70)
+
+    now = datetime.datetime.now()
+    print(f"Start Time: {now.strftime('%Y-%m-%d %H:%M:%S')}")
+
+    # Set path
+    CACHE_DIR = Path("/shared/eng/pyhealth_dka/cache/mp_mimic4")
+    CKPTS_DIR = Path("/shared/eng/pyhealth_dka/ckpts/mp_transformer_mimic4")
+    OUTPUT_DIR = Path("/shared/eng/pyhealth_dka/output/mp_transformer_mimic4")
+    CACHE_DIR.mkdir(parents=True, exist_ok=True)
+    CKPTS_DIR.mkdir(parents=True, exist_ok=True)
+    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+    print(f"\nUsing cache dir: {CACHE_DIR}")
+    print(f"Using checkpoints dir: {CKPTS_DIR}")
+    print(f"Using output dir: {OUTPUT_DIR}")
+
+    # Set device
+    device = args.device
+    print(f"\nUsing device: {device}")
+
+    # Fail fast: verify every required artifact before the expensive dataset load.
+    ckpt_path = CKPTS_DIR / "best.ckpt"
+    if not (CKPTS_DIR / "input_processors.pkl").exists():
+        raise FileNotFoundError(f"Input processors not found in {CKPTS_DIR}. ")
+    if not (CKPTS_DIR / "output_processors.pkl").exists():
+        raise FileNotFoundError(f"Output processors not found in {CKPTS_DIR}. ")
+    if not ckpt_path.exists():
+        raise FileNotFoundError(f"Model checkpoint not found: {ckpt_path}")
+
+    # Load MIMIC-IV dataset
+    print("\n Loading MIMIC-IV dataset...")
+    base_dataset = MIMIC4Dataset(
+        ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
+        ehr_tables=[
+            "patients",
+            "admissions",
+            "diagnoses_icd",
+            "procedures_icd",
+            "labevents",
+        ],
+        cache_dir=str(CACHE_DIR),
+        num_workers=16,
+    )
+
+    # Apply mortality prediction task
+    input_processors, output_processors = load_processors(str(CKPTS_DIR))
+    print("✓ Loaded input and output processors from checkpoint directory.")
+
+    sample_dataset = base_dataset.set_task(
+        MortalityPredictionStageNetMIMIC4(),
+        num_workers=16,
+        input_processors=input_processors,
+        output_processors=output_processors,
+    )
+    print(f"✓ Loaded {len(sample_dataset)} samples")
+
+    # Split dataset and get test loader
+    _, _, test_dataset = split_by_patient(sample_dataset, [0.9, 0.09, 0.01], seed=233)
+    test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)
+    print(f"✓ Test set: {len(test_dataset)} samples")
+
+    # Initialize and load pre-trained model
+    print("\n Loading pre-trained Transformer model...")
+    model = Transformer(
+        dataset=sample_dataset,
+        embedding_dim=128,
+        heads=4,
+        dropout=0.3,
+        num_layers=3,
+    )
+
+    trainer = Trainer(model=model, device=device)
+    trainer.load_ckpt(str(ckpt_path))
+    model = model.to(device)
+    model.eval()
+    print(f"✓ Loaded checkpoint: {ckpt_path}")
+    print(f"✓ Model moved to {device}")
+
+    methods: dict[str, BaseInterpreter] = {
+        "random": RandomBaseline(model),
+        "shap (emb)": ShapExplainer(model, use_embeddings=True),
+        "shap": ShapExplainer(model, use_embeddings=False),
+        "lime (emb)": LimeExplainer(model, use_embeddings=True),
+        "lime": LimeExplainer(model, use_embeddings=False),
+    }
+    print(f"\nEvaluating methods: {list(methods.keys())}")
+
+    res = {}
+    for name, method in methods.items():
+        print(f"\n Initializing {name}...")
+        print("=" * 70)
+
+        # Option 1: Functional API (simple one-off evaluation)
+        print("\nEvaluating with Functional API on full dataset...")
+        print("Using: evaluate_attribution(model, dataloader, method, ...)")
+
+        results_functional = evaluate_attribution(
+            model,
+            test_loader,
+            method,
+            metrics=["comprehensiveness", "sufficiency"],
+            percentages=[25, 50, 99],
+        )
+
+        print("\n" + "=" * 70)
+        print("Dataset-Wide Results (Functional API)")
+        print("=" * 70)
+        comp = results_functional["comprehensiveness"]
+        suff = results_functional["sufficiency"]
+        print(f"\nComprehensiveness: {comp:.4f}")
+        print(f"Sufficiency: {suff:.4f}")
+
+        res[name] = {
+            "comp": float(comp),
+            "suff": float(suff),
+        }
+
+    print("")
+    print("=" * 70)
+    print("Summary of Results for All Methods")
+    print("=" * 70)
+    # One row per method, one column per metric.
+    summary = pd.DataFrame.from_dict(res, orient="index")
+    summary.index.name = "method"
+    print(summary.to_string(float_format="{:.4f}".format))
+
+    # Persist the table so a long run is not lost with the terminal output.
+    summary_path = OUTPUT_DIR / "interpretability_summary.csv"
+    summary.to_csv(summary_path)
+    print(f"\nSaved summary to {summary_path}")
+
+    end = datetime.datetime.now()
+    print(f"End Time: {end.strftime('%Y-%m-%d %H:%M:%S')}")
+    print(f"Total Duration: {end - now}")
+
+if __name__ == "__main__":
+    main()
diff --git a/pyhealth/interpret/methods/lime.py b/pyhealth/interpret/methods/lime.py index 8481c50bf..5176bfeaf 100644 --- a/pyhealth/interpret/methods/lime.py +++ b/pyhealth/interpret/methods/lime.py @@ -12,7 +12,7 @@ from .base_interpreter import BaseInterpreter -class LimeExplainer(BaseInterpreter): +class Lime(BaseInterpreter): """LIME (Local Interpretable Model-agnostic Explanations) attribution method for PyHealth models. 
This class implements the LIME method for computing feature attributions in @@ -488,8 +488,12 @@ def _evaluate_sample( Returns: Model prediction for the perturbed sample, shape (batch_size, ). """ - # Embed continuous (non-token) perturbed features that are still raw + inputs = inputs.copy() + if self.use_embeddings: + # Embed continuous (non-token) perturbed features that are still raw + # so forward_from_embedding receives proper embeddings. + # Token features were already embedded before perturbation. embedding_model = self.model.get_embedding_model() assert embedding_model is not None, ( "Model must have an embedding model for embedding-based LIME." @@ -506,13 +510,18 @@ def _evaluate_sample( for k, v in perturb.items() } - inputs = inputs.copy() for k in inputs.keys(): # Insert perturbed value tensor back into input tuple schema = self.model.dataset.input_processors[k].schema() inputs[k] = (*inputs[k][:schema.index("value")], perturb[k], *inputs[k][schema.index("value")+1:]) - - logits = self.model.forward_from_embedding(**inputs)["logit"] + + if self.use_embeddings: + # Values are already embedded; bypass the model's own embedding. + logits = self.model.forward_from_embedding(**inputs)["logit"] + else: + # Values are raw (token IDs / continuous floats); let the + # model's regular forward pass handle embedding internally. + logits = self.model.forward(**inputs)["logit"] # Reduce to [batch_size, ] by taking absolute difference from target class logit return (target - logits).abs().mean(dim=tuple(range(1, logits.ndim))) @@ -748,11 +757,14 @@ def _generate_baseline( for k, v in values.items(): processor = self.model.dataset.input_processors[k] - if use_embeddings and processor.is_token(): - # Token features: UNK token index as baseline + if processor.is_token(): + # Token features: UNK token (index 1) as baseline. 
+ # When use_embeddings=True, embedding happens later in + # attribute(); when use_embeddings=False, the UNK token + # IDs are used directly as the perturbed replacement. baseline = torch.ones_like(v) else: - # Continuous features (or non-embedding mode): near-zero baseline + # Continuous features: use small neutral values (near-zero) baseline = torch.zeros_like(v) + 1e-2 baselines[k] = baseline @@ -823,3 +835,5 @@ def _map_to_input_shapes( mapped[key] = reshaped return mapped + +LimeExplainer = Lime # Alias for backward compatibility \ No newline at end of file diff --git a/pyhealth/interpret/methods/shap.py b/pyhealth/interpret/methods/shap.py index edcd02fc2..46d40a977 100644 --- a/pyhealth/interpret/methods/shap.py +++ b/pyhealth/interpret/methods/shap.py @@ -10,7 +10,7 @@ from .base_interpreter import BaseInterpreter -class ShapExplainer(BaseInterpreter): +class Shap(BaseInterpreter): """SHAP (SHapley Additive exPlanations) attribution method for PyHealth models. This class implements the SHAP method for computing feature attributions in @@ -121,8 +121,6 @@ def __init__( implement forward_from_embedding() method. """ super().__init__(model) - if not isinstance(model, Interpretable): - raise ValueError("Model must implement Interpretable interface") self.model = model self.use_embeddings = use_embeddings self.n_background_samples = n_background_samples @@ -131,13 +129,8 @@ def __init__( self.random_seed = random_seed # Validate model requirements - if use_embeddings: - assert hasattr(model, "forward_from_embedding"), ( - f"Model {type(model).__name__} must implement " - "forward_from_embedding() method to support embedding-level " - "SHAP values. Set use_embeddings=False to use " - "input-level attributions (only for continuous features)." 
- ) + if use_embeddings and not isinstance(model, Interpretable): + raise ValueError("Model must implement Interpretable interface or use_embeddings must be False.") # ------------------------------------------------------------------ # Public API @@ -275,7 +268,7 @@ def attribute( # (raw indices are meaningless for interpolation), while continuous # features stay raw so each raw dimension gets its own SHAP value. # Continuous features will be embedded inside _evaluate_sample(). - if self.use_embeddings: + if self.use_embeddings and isinstance(self.model, Interpretable): embedding_model = self.model.get_embedding_model() assert embedding_model is not None, ( "Model must have an embedding model for embedding-based SHAP." @@ -507,10 +500,12 @@ def _evaluate_sample( Target-class prediction scalar per batch item, shape (batch_size,). """ inputs = inputs.copy() - # For continuous (non-token) features, embed through embedding_model - # so forward_from_embedding receives proper embeddings. - # Token features were already embedded before perturbation. - if self.use_embeddings: + + if self.use_embeddings and isinstance(self.model, Interpretable): + # For continuous (non-token) features, embed through + # embedding_model so forward_from_embedding receives proper + # embeddings. Token features were already embedded before + # perturbation. embedding_model = self.model.get_embedding_model() assert embedding_model is not None, ( "Model must have an embedding model for embedding-based SHAP." @@ -534,7 +529,13 @@ def _evaluate_sample( *inputs[k][schema.index("value") + 1 :], ) - logits = self.model.forward_from_embedding(**inputs)["logit"] + if self.use_embeddings and isinstance(self.model, Interpretable): + # Values are already embedded; bypass the model's own embedding. + logits = self.model.forward_from_embedding(**inputs)["logit"] + else: + # Values are raw (token IDs / continuous floats); let the + # model's regular forward pass handle embedding internally. 
+ logits = self.model.forward(**inputs)["logit"] return self._extract_target_prediction(logits, target) @@ -666,9 +667,11 @@ def _generate_background_samples( baselines = {} for k, v in values.items(): - if use_embeddings and self.model.dataset.input_processors[k].is_token(): + if self.model.dataset.input_processors[k].is_token(): # Token features: UNK token (index 1) as baseline. - # Embedding happens later in attribute(). + # When use_embeddings=True, embedding happens later in + # attribute(); when use_embeddings=False, the UNK token + # IDs are used directly as the perturbed replacement. baseline = torch.ones_like(v) else: # Continuous features: use small neutral values (near-zero) @@ -771,4 +774,6 @@ def _map_to_input_shapes( mapped[key] = reshaped - return mapped \ No newline at end of file + return mapped + +ShapExplainer = Shap # Alias for backward compatibility \ No newline at end of file diff --git a/pyhealth/processors/base_processor.py b/pyhealth/processors/base_processor.py index 39292df43..06bbe0c7c 100644 --- a/pyhealth/processors/base_processor.py +++ b/pyhealth/processors/base_processor.py @@ -184,6 +184,16 @@ def retain(self, tokens: set[str]): def add(self, tokens: set[str]): """Add specified vocabularies to the processor.""" pass + + @abstractmethod + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + pass + + @abstractmethod + def vocab_size(self) -> int: + """Return the size of the processor's vocabulary.""" + pass class TemporalFeatureProcessor(FeatureProcessor): @@ -250,4 +260,4 @@ def process(self, value) -> dict[str, torch.Tensor]: def schema(self) -> tuple[str, ...]: """Standardised schema: at minimum ``('value', 'time')``.""" - return ("value", "time") \ No newline at end of file + return ("value", "time") diff --git a/pyhealth/processors/deep_nested_sequence_processor.py b/pyhealth/processors/deep_nested_sequence_processor.py index 2d67633eb..24683f54b 100644 --- 
a/pyhealth/processors/deep_nested_sequence_processor.py +++ b/pyhealth/processors/deep_nested_sequence_processor.py @@ -107,6 +107,10 @@ def add(self, tokens: set[str]): self.code_vocab[token] = i i += 1 + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + return set(self.code_vocab.keys()) + def process(self, value: List[List[List[Any]]]) -> torch.Tensor: """Process deep nested sequence into padded 3D tensor. @@ -169,6 +173,10 @@ def process(self, value: List[List[List[Any]]]) -> torch.Tensor: return torch.tensor(encoded_groups, dtype=torch.long) + def vocab_size(self) -> int: + """Return the size of the processor's vocabulary.""" + return len(self.code_vocab) + def size(self) -> int: """Return max inner length (embedding dimension) for unified API.""" return self._max_inner_len diff --git a/pyhealth/processors/nested_sequence_processor.py b/pyhealth/processors/nested_sequence_processor.py index 461575621..c03c800d0 100644 --- a/pyhealth/processors/nested_sequence_processor.py +++ b/pyhealth/processors/nested_sequence_processor.py @@ -103,6 +103,10 @@ def add(self, tokens: set[str]): self.code_vocab[token] = i i += 1 + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + return set(self.code_vocab.keys()) + def process(self, value: List[List[Any]]) -> torch.Tensor: """Process nested sequence into padded 2D tensor. 
@@ -146,6 +150,10 @@ def process(self, value: List[List[Any]]) -> torch.Tensor: return torch.tensor(encoded_sequences, dtype=torch.long) + def vocab_size(self) -> int: + """Return the size of the processor's vocabulary.""" + return len(self.code_vocab) + def size(self) -> int: """Return max inner length (embedding dimension) for unified API.""" return self._max_inner_len diff --git a/pyhealth/processors/sequence_processor.py b/pyhealth/processors/sequence_processor.py index d7a7b1ddf..7792709bb 100644 --- a/pyhealth/processors/sequence_processor.py +++ b/pyhealth/processors/sequence_processor.py @@ -68,6 +68,14 @@ def add(self, tokens: set[str]): self.code_vocab[token] = i i += 1 + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + return set(self.code_vocab.keys()) + + def vocab_size(self) -> int: + """Return the size of the processor's vocabulary.""" + return len(self.code_vocab) + def size(self): return len(self.code_vocab) diff --git a/pyhealth/processors/stagenet_processor.py b/pyhealth/processors/stagenet_processor.py index ba6bd59aa..604376ec1 100644 --- a/pyhealth/processors/stagenet_processor.py +++ b/pyhealth/processors/stagenet_processor.py @@ -139,6 +139,10 @@ def add(self, tokens: set[str]): self.code_vocab[token] = i i += 1 + def tokens(self) -> set[str]: + """Return the set of tokens in the processor's vocabulary.""" + return set(self.code_vocab.keys()) + def process( self, value: Tuple[Optional[List], List] ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: @@ -217,6 +221,10 @@ def _encode_nested_codes(self, nested_codes: List[List[str]]) -> torch.Tensor: return torch.tensor(encoded_sequences, dtype=torch.long) + def vocab_size(self) -> int: + """Return the size of the processor's vocabulary.""" + return len(self.code_vocab) + def size(self) -> int: """Return vocabulary size.""" return len(self.code_vocab)