MIRA-Large: Open-Source Reproduction
Model Summary
This repository hosts the Large-scale configuration of the open-source demonstration for MIRA: Medical Time Series Foundation Model for Real-World Health Data.
The goal of this release is to offer a fully accessible and reproducible version of the MIRA framework. Unlike the internal version described in the paper, which uses private clinical data, this model was trained exclusively on publicly available time-series datasets. It is designed to showcase the core architectural innovations of MIRA, such as CT-RoPE and Frequency-Specialized MoE, without compromising patient privacy.
Disclaimer: This is a demonstration reproduction. While it shares the same "Large" architecture and mechanisms as the original MIRA, its performance on downstream clinical tasks may differ due to the restriction to public training data.
Key Resources
- 📄 Paper: MIRA: Medical Time Series Foundation Model for Real-World Health Data (arXiv)
- 💻 Official Code: github.com/microsoft/MIRA
Technical Specifications
Architecture Highlights
This demonstration reproduces the key components of the MIRA architecture:
- Continuous-Time Rotary Positional Encoding (CT-RoPE): Enables geometric time encoding specifically designed for irregularly sampled medical signals.
  - Configuration: The temporal unit is set to 0.1.
- Frequency-Specialized Mixture-of-Experts (MoE): A routing mechanism that allows experts to specialize in distinguishing between low-frequency trends and high-frequency clinical anomalies.
- Neural ODE–Based Continuous Dynamics Modeling: Captures long-range physiological trajectories using continuous-time formulations, making it robust to missing data.
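To make the idea behind CT-RoPE concrete, here is a minimal pure-Python sketch of a rotary encoding driven by real-valued timestamps rather than integer token positions. It is an illustration under stated assumptions, not the official implementation: the function name `ct_rope_rotate`, the `base` frequency schedule, and the choice to divide the timestamp by `time_unit` are all hypothetical, and the way the official code applies its 0.1 temporal unit may differ.

```python
import math

def ct_rope_rotate(x, t, time_unit=0.1, base=10000.0):
    """Rotate consecutive pairs of `x` by angles proportional to timestamp `t`.

    Hypothetical helper: `time_unit` and `base` are illustrative assumptions.
    """
    d = len(x)
    if d % 2 != 0:
        raise ValueError("feature dimension must be even")
    out = []
    for i in range(d // 2):
        # Standard rotary frequency schedule; the angle uses a continuous
        # timestamp instead of an integer token index.
        inv_freq = base ** (-2.0 * i / d)
        angle = (t / time_unit) * inv_freq
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        out.extend([a * c - b * s, a * s + b * c])
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.3, -0.7, 1.0]

# Two pairs of irregular timestamps with the same gap (2.5 time units).
score_a = dot(ct_rope_rotate(q, 3.7), ct_rope_rotate(k, 1.2))
score_b = dot(ct_rope_rotate(q, 13.7), ct_rope_rotate(k, 11.2))

# The attention score depends only on the time gap, not on absolute time.
print(abs(score_a - score_b) < 1e-9)
```

The property checked at the end is what makes a rotary scheme suitable for irregular sampling: the query-key score is a function of the elapsed time between two observations, whatever their absolute timestamps are.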
Model Configuration
- Size: Large
- Training Data: Publicly available medical time-series datasets only.
Usage
You can load this model using the official codebase. Please ensure you have cloned the repository first.
import torch
from MIRA.mira.models.modeling_mira import MIRAForPrediction
# Imported as in the original example; not called in this snippet.
from MIRA.mira.models.utils_time_normalization import normalize_time_for_ctrope

# Load the pre-trained model
model = MIRAForPrediction.from_pretrained(ckpt_path).cuda()
model.eval()

# Example inference (pseudo-code).
# Assumes ckpt_path, hist_vals, hist_times, future_times, P (number of
# forecast steps), mean, and std are already defined; shapes follow the
# inline comments below.
device = next(model.parameters()).device
hist_vals = hist_vals.to(device)
hist_times = hist_times.to(device)
future_times = future_times.to(device)

cur_vals = hist_vals.clone()
cur_times = hist_times.clone()
preds_norm = []

# Autoregressive rollout: predict one step, append it, and repeat.
for i in range(P):
    # model input
    inp_vals = cur_vals.unsqueeze(-1)  # [1, L, 1]
    inp_times = cur_times  # [1, L]

    with torch.no_grad():
        out = model(
            input_ids=inp_vals,
            time_values=inp_times,
            next_target_time_values=None,  # no ODE for 1-step
            return_dict=True,
        )

    next_norm = out.logits[:, -1, :]  # [1, 1]
    preds_norm.append(next_norm.squeeze(0))

    # Extend the context with the prediction and its timestamp.
    next_t = future_times[:, i:i+1]
    cur_vals = torch.cat([cur_vals, next_norm], dim=1)
    cur_times = torch.cat([cur_times, next_t], dim=1)

preds_norm = torch.stack(preds_norm, dim=1)  # [1, P]

# Map the normalized predictions back to the original scale.
preds = preds_norm * std[:, :, :] + mean[:, :, :]
preds = preds.squeeze(0)
print(preds)
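The example above uses `mean` and `std` without showing where they come from. One plausible reading, sketched below in plain Python, is that they are per-series z-score statistics computed from the history window, so the final line simply inverts that normalization. This is an assumption rather than the repository's documented preprocessing; the helper names `fit_zscore`, `normalize`, and `denormalize`, the `eps` term, and the sample readings are all made up for illustration.

```python
import math

def fit_zscore(history, eps=1e-8):
    """Return (mean, std) of a history window; eps guards against zero std."""
    mean = sum(history) / len(history)
    var = sum((v - mean) ** 2 for v in history) / len(history)
    return mean, math.sqrt(var) + eps

def normalize(values, mean, std):
    return [(v - mean) / std for v in values]

def denormalize(values, mean, std):
    # Same operation as `preds_norm * std + mean` in the example above.
    return [v * std + mean for v in values]

history = [72.0, 75.0, 71.0, 78.0, 74.0]  # made-up readings
mean, std = fit_zscore(history)

hist_norm = normalize(history, mean, std)   # what the model would consume
preds_norm = [0.5, -0.25]                   # stand-in for model outputs
preds = denormalize(preds_norm, mean, std)  # back on the original scale
print(preds)
```

Keeping the statistics from the history window only, and never from the future values, avoids leaking information about the forecast target into the model input.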