Skip to content

Comments

Add MultimodalJambaEHR with UnifiedMultimodalEmbedding and TemporalFeatureProcessor#874

Open
joshuasteier wants to merge 1 commit intosunlabuiuc:masterfrom
Multimodal-PyHealth:feature/multimodal-jamba
Open

Add MultimodalJambaEHR with UnifiedMultimodalEmbedding and TemporalFeatureProcessor#874
joshuasteier wants to merge 1 commit intosunlabuiuc:masterfrom
Multimodal-PyHealth:feature/multimodal-jamba

Conversation

@joshuasteier
Copy link
Collaborator

Contributor

  • Josh Steier

Description

Adds the multimodal Jamba backbone, unified multimodal embedding layer, and temporal feature processor abstract base class for the multimodal mortality prediction pipeline.

Files

File Description
pyhealth/models/multimodal_jamba.py MultimodalJambaEHR, UnifiedMultimodalEmbedding, JambaBackbone with 3-tier Mamba backend (mamba-ssm CUDA → PyHealth MambaBlock → pure PyTorch parallel scan)
pyhealth/processors/temporal_feature_processor.py TemporalFeatureProcessor abstract base class — Rian's text processor and William's timeseries processor should inherit from this
tests/core/test_multimodal_jamba.py 36 unit tests

Architecture

Per-modality encoders (B, S_i, E') + timestamps (B, S_i)
    → UnifiedMultimodalEmbedding
        - Sinusoidal time embeddings
        - Learnable modality-type embeddings (IMAGE/TEXT/TIMESERIES/SEQUENCE)
        - Learnable missing-modality tokens
        - Optional [CLS] token
    → (B, S_total, E') concatenated sequence
    → JambaBackbone (interleaved Transformer + Mamba layers)
    → Pooling (CLS / mean / last)
    → FC classification head

Missing Modality Handling (per Feb 16 meeting notes)

  • If a modality is absent, a learnable missing token + modality-type embedding is substituted
  • Model degrades gracefully: EHR-only, text-only, any combination works

Mamba Backend Priority

  1. mamba-ssm CUDA kernels (fastest, pip install mamba-ssm)
  2. PyHealth MambaBlock from ehr_mamba.py
  3. Pure PyTorch parallel scan fallback (no deps needed)

On the campus cluster with mamba-ssm installed, it auto-selects CUDA. No code changes needed.

Testing

python pyhealth/models/multimodal_jamba.py                    # 7 smoke tests
python -m unittest tests/core/test_multimodal_jamba.py -v      # 36 unit tests

All 36 tests pass.

Usage

from pyhealth.models.multimodal_jamba import MultimodalJambaEHR, ModalityType

model = MultimodalJambaEHR(
    embedding_dim=128,
    num_transformer_layers=2,
    num_mamba_layers=6,
    heads=4,
    num_classes=2,
)

inputs = {
    ModalityType.IMAGE: (image_embeddings, image_times),
    ModalityType.TEXT: (text_embeddings, text_times),
    ModalityType.TIMESERIES: (ts_embeddings, ts_times),
    ModalityType.SEQUENCE: (code_embeddings, code_times),
}

out = model(inputs, labels=labels)
out["loss"].backward()

Note on BaseModel

Currently inherits nn.Module instead of PyHealth BaseModel. This is intentional — BaseModel integration requires the full processor → dataset → task pipeline (William's multimodal task + processors). Will be updated once those pieces land.

@joshuasteier joshuasteier force-pushed the feature/multimodal-jamba branch from 4095ec6 to cd82d07 Compare February 24, 2026 20:02
@Rian354 Rian354 force-pushed the feature/multimodal-jamba branch 2 times, most recently from 0b6218b to cd82d07 Compare February 25, 2026 00:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant