MorphoCLIP Training Pipeline
In Plain English: MorphoCLIP learns to match microscopy images with text descriptions of biological treatments. It takes cell images processed through a vision AI (DINOv3) and text descriptions processed through a language AI (ModernBERT), then trains small "adapter" networks to place matching image-text pairs near each other in a shared mathematical space. See the Glossary for term definitions.
Architecture
Status
| Component | Status |
|---|---|
| DINOv3 feature extraction | Done |
| Text encoder + projection head | Done |
| Dataset + collate + splits | Done |
| Text embedding caching (768-d) | Done |
| CrossChannelFormer | Done |
| Image projection head | Done |
| InfoNCE loss | Done |
| CWCL loss (soft labels) | Done |
| CWA batch correction | Done |
| Training loop + checkpointing | Done |
| Retrieval metrics (R@k, MRR) | Done |
| Perturbation matching metrics | To build |
| LoRA fine-tuning | To build |
Architecture Overview
IMAGE BRANCH                               TEXT BRANCH
────────────                               ───────────
5 fluorescence channels/site               PerturbationInfo
        │                                        │
DINOv3 ViT-L/16 (frozen)                   PromptBuilder (verbose templates)
        │                                        │
CLS tokens (5, 1024)/site                  BioClinical ModernBERT (frozen)
        │                                        │
┌─────────────────────────┐                CLS token (768-d) → cached
│   CrossChannelFormer    │                      │
│   (1-layer xformer)     │                ProjectionHead (768→512)
└────────────┬────────────┘                      │
      (1, 1024)/site                       512-d L2-normalized
        │
Image ProjectionHead (1024→512)
        │
512-d L2-normalized
        │                                        │
        └───────────────┬────────────────────────┘
                        │
             CWCL Contrastive Loss
             + CWA batch correction

Components to Build
1. CrossChannelFormer (src/morphoclip/models/cross_channel_former.py)
Aggregates 5 per-channel CLS tokens into a single image representation per site.
Why needed: DINOv3 processes each of the 5 fluorescence channels independently (replicated to pseudo-RGB). The CrossChannelFormer learns cross-channel interactions — e.g., how mitochondrial morphology relates to actin cytoskeleton structure — producing one unified 1024-d token per site.
Input: (batch, 5, 1024) — 5 channel CLS tokens
Output: (batch, 1024) — single aggregated representation

Design:
- 1-layer transformer encoder (nn.TransformerEncoder, default; configurable via ccf_layers)
- 5 learnable channel-type embeddings (one per channel: Mitochondria, Actin, Golgi, ER, DNA) added to the input tokens
- Aggregation: mean-pool over the 5 output tokens (or use a learnable CLS query token)
- 4 attention heads (configurable via ccf_heads)
- ~4M parameters
Alternatives considered:
- Simple mean pooling (no learnable cross-channel interaction — baseline ablation)
- Single attention layer (fewer params but may underfit)
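The design above can be sketched in PyTorch as follows. This is a minimal sketch assuming the default config values (dim=1024, depth=1, heads=4); the real module may differ in details such as the feedforward width or a learnable CLS query for aggregation.

```python
import torch
import torch.nn as nn

class CrossChannelFormer(nn.Module):
    """Aggregates per-channel CLS tokens into one site-level token (sketch)."""

    def __init__(self, dim: int = 1024, depth: int = 1, heads: int = 4, n_channels: int = 5):
        super().__init__()
        # One learnable embedding per channel type (Mito, Actin, Golgi, ER, DNA)
        self.channel_embed = nn.Parameter(torch.zeros(n_channels, dim))
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=4 * dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_channels, dim) channel CLS tokens
        x = x + self.channel_embed   # add channel-type embeddings
        x = self.encoder(x)          # cross-channel attention
        return x.mean(dim=1)         # mean-pool -> (batch, dim)
```

Swapping the mean-pool for a learnable query token changes only the last line plus one extra parameter.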
2. Image Projection Head (src/morphoclip/models/image_projection_head.py)
Maps aggregated image representation to shared contrastive space.
Input: (batch, 1024) — CrossChannelFormer output
Output: (batch, 512) — L2-normalized embedding

Design: Reuse the existing ProjectionHead class with input_dim=1024:

```python
image_proj = ProjectionHead(input_dim=1024, hidden_dim=512, output_dim=512)
```

3. Site Aggregation
Each well has multiple imaging sites (up to 9). Sites must be aggregated into a single well-level embedding before loss computation.
Options (in order of simplicity):
- Mean pooling — average site embeddings (use site_mask from collate_fn to exclude padding)
- Attention pooling — learnable weighted average over sites
- Max pooling — element-wise max

Recommendation: Start with masked mean pooling. It's simple, proven, and the site-level variation within a well is small.
4. CWCL Loss (src/morphoclip/training/losses.py)
Continuously Weighted Contrastive Loss — handles soft positives (multiple wells with the same perturbation).
Standard CLIP uses hard positive pairs (1 image β 1 text). In Cell Painting, many wells share the same perturbation, creating multiple valid positives per text description. CWCL assigns continuous similarity weights instead of binary labels.
Input: image_embeds (B, 512), text_embeds (B, 512)
Output: scalar loss

Algorithm:
1. Compute the cosine similarity matrix: S = image_embeds @ text_embeds.T → (B, B)
2. Build a soft label matrix W where W[i,j] reflects how similar perturbation_i is to perturbation_j:
   - Same broad_sample → weight 1.0
   - Same target gene (different compound) → weight 0.5–0.8
   - Different perturbation → weight 0.0
3. Temperature-scaled cross-entropy with soft labels:
   loss_i2t = -sum(W[i,:] * log_softmax(S[i,:] / tau))
   loss_t2i = -sum(W[:,j] * log_softmax(S[:,j] / tau))
   loss = (loss_i2t + loss_t2i) / 2
4. Learnable temperature tau (initialized to 0.07, log-parameterized)
Fallback: Start with standard symmetric InfoNCE (CLIP loss) as a baseline, then add soft weighting.
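The algorithm above can be sketched as a loss function. This is a sketch, not the planned losses.py implementation: it assumes W is precomputed elsewhere from the perturbation metadata, and row/column normalization of W is one reasonable choice the source does not confirm.

```python
import torch
import torch.nn.functional as F

def cwcl_loss(image_embeds, text_embeds, W, log_tau):
    """Continuously weighted contrastive loss (sketch).

    image_embeds, text_embeds: (B, 512), already L2-normalized
    W: (B, B) soft label matrix, W[i, j] in [0, 1]
    log_tau: learnable log-temperature (scalar tensor)
    """
    tau = log_tau.exp()
    S = image_embeds @ text_embeds.T  # (B, B) cosine similarities
    # Normalize W so each query's positive weights sum to 1
    W_i2t = W / W.sum(dim=1, keepdim=True).clamp(min=1e-8)
    W_t2i = W / W.sum(dim=0, keepdim=True).clamp(min=1e-8)
    loss_i2t = -(W_i2t * F.log_softmax(S / tau, dim=1)).sum(dim=1).mean()
    loss_t2i = -(W_t2i * F.log_softmax(S / tau, dim=0)).sum(dim=0).mean()
    return (loss_i2t + loss_t2i) / 2
```

With W set to the identity matrix this reduces to the standard symmetric InfoNCE baseline, which makes the fallback a one-line change.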
5. CWA Batch Correction (src/morphoclip/training/batch_correction.py)
Cross-Well Alignment — removes plate-specific technical artifacts from image embeddings during training.
Cell Painting images have significant batch effects (plate-to-plate variation from staining, imaging conditions). Without correction, the model learns to cluster by plate rather than by perturbation.
Algorithm:
1. Within each training batch, group image embeddings by plate
2. Compute the per-plate mean embedding: mu_p = mean(embeds[plate == p])
3. Subtract the plate mean: embeds_corrected = embeds - mu_p
4. Re-normalize to the unit sphere: embeds_corrected = L2_norm(embeds_corrected)
When applied: After the image projection head, before loss computation. Text embeddings are not corrected (no batch effect in text).
Requires: Batch sampler that ensures each batch contains wells from multiple plates (see Data Loading below).
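The four steps above could be implemented roughly as follows (a sketch; the real batch_correction.py interface may differ, e.g. in how plate IDs are encoded):

```python
import torch
import torch.nn.functional as F

def cross_well_align(embeds: torch.Tensor, plates: torch.Tensor) -> torch.Tensor:
    """Subtract the per-plate mean within the batch, then re-normalize (sketch).

    embeds: (B, D) image embeddings
    plates: (B,) integer plate IDs
    """
    corrected = embeds.clone()
    for p in plates.unique():
        idx = plates == p
        # mu_p subtraction; a plate with a single well in the batch would
        # collapse to zero, which is why multi-plate batches are required
        corrected[idx] = embeds[idx] - embeds[idx].mean(dim=0)
    return F.normalize(corrected, dim=-1)  # back onto the unit sphere
```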
6. Full Image Encoder (src/morphoclip/models/image_encoder.py)
Wraps the full image-side pipeline:
```python
class MorphoCLIPImageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.cross_channel_former = CrossChannelFormer(dim=1024, depth=1, heads=4)
        self.projection_head = ProjectionHead(input_dim=1024, output_dim=512)

    def forward(self, features, site_mask):
        """
        Args:
            features: (B, max_sites, 5, 1024) — padded DINOv3 CLS tokens
            site_mask: (B, max_sites) — True for real sites
        Returns:
            (B, 512) — L2-normalized well embeddings
        """
        B, S, C, D = features.shape
        # Merge batch and site dims
        x = features.view(B * S, C, D)          # (B*S, 5, 1024)
        x = self.cross_channel_former(x)        # (B*S, 1024)
        x = x.view(B, S, -1)                    # (B, S, 1024)
        # Masked mean pooling over sites
        mask = site_mask.unsqueeze(-1).float()  # (B, S, 1)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)  # (B, 1024)
        x = self.projection_head(x)             # (B, 512), L2-normalized
        return x
```

Training Loop Design
Data Loading
```python
# 1. Build dataset
metadata = MetadataIndex.from_config(config_path)
dataset = MorphoCLIPDataset(
    feature_dir=Path("data/features"),
    metadata=metadata,
    plates=train_plates,
    mode="features",
    text_level="full",
    exclude_controls=True,
)

# 2. Split (stratified by broad_sample across all pert types)
train_set, val_set, test_set = create_splits(
    dataset,
    strategy="pert_type",
    val_fraction=0.1,  # 10% val, 10% test, 80% train
)

# 3. DataLoader with custom collate
train_loader = DataLoader(
    train_set, batch_size=512, shuffle=True, collate_fn=collate_fn, num_workers=0
)
```

Split strategy: The pert_type strategy uses only local metadata (data/metadata/).
Each unique broad_sample is deterministically hashed to train/val/test, so all wells
sharing the same perturbation land in the same split. Compounds, CRISPR, and ORF are all
represented in every split, ensuring the model trains and evaluates on all perturbation
modalities. Benchmark-aligned strategies (cpjump1_official_*, cellclip_cpjump_style)
are available in benchmark.splits for comparison with published baselines.
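The deterministic hash-to-split assignment could look like this (a sketch; the actual create_splits implementation may use a different hash or bucketing scheme):

```python
import hashlib

def assign_split(broad_sample: str, val_fraction: float = 0.1,
                 test_fraction: float = 0.1) -> str:
    """Deterministically map a perturbation ID to train/val/test.

    All wells sharing a broad_sample land in the same bucket, on every run,
    with no dependence on dataset ordering or random seeds.
    """
    digest = hashlib.md5(broad_sample.encode("utf-8")).hexdigest()
    u = int(digest[:8], 16) / 0xFFFFFFFF  # pseudo-uniform value in [0, 1]
    if u < val_fraction:
        return "val"
    if u < val_fraction + test_fraction:
        return "test"
    return "train"
```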
Note: CWA is disabled by default (use_cwa: false in base config). When enabled, each batch needs wells from multiple plates. With 51 plates and batch_size=512, random shuffling provides sufficient plate diversity.
Data loading: With preload: true (default), all features are loaded into RAM at startup with num_workers: 0. Set preload: false and increase num_workers if RAM is constrained.
Text Embedding Strategy
Text embeddings are pre-computed and cached (768-d raw BERT features). During training:
- Load cached text embeddings at startup using load_cached_text_features()
- Look up the cached 768-d embedding by broad_sample ID
- Pass through the (trainable) text ProjectionHead to get 512-d
- This avoids running the BERT forward pass every epoch (~10x faster)
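The per-batch lookup step could be as simple as the following sketch, assuming the cache is a dict from broad_sample ID to a (768,)-tensor and that each pert_info entry exposes a broad_sample attribute (both assumptions, not confirmed by the source):

```python
import torch

def lookup_cached_embeddings(pert_info, cache):
    """Stack cached 768-d BERT embeddings for one batch (sketch).

    pert_info: sequence of objects with a .broad_sample attribute
    cache: dict mapping broad_sample ID -> (768,) tensor
    """
    return torch.stack([cache[p.broad_sample] for p in pert_info])
```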
Training Step Pseudocode
```python
for batch in train_loader:
    # Image branch
    image_embeds = image_encoder(batch["features"], batch["site_mask"])  # (B, 512)

    # Text branch (from cache)
    text_raw = lookup_cached_embeddings(batch["pert_info"])  # (B, 768)
    text_embeds = text_projection_head(text_raw)             # (B, 512)

    # Batch correction (CWA)
    image_embeds = cross_well_align(image_embeds, batch["plates"])

    # Loss
    loss = cwcl_loss(image_embeds, text_embeds, batch["pert_info"])

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Hyperparameters
| Parameter | Value | Notes |
|---|---|---|
| Batch size | 512 | More negatives for contrastive learning |
| Learning rate | 1e-4 | AdamW, weight_decay=0.2 |
| LR schedule | Cosine annealing with warmup | 100 warmup steps |
| Temperature (tau) | Learnable (LogitScaleModule) | Log-parameterized, clamped to ln(100) |
| Epochs | 100 | Early stopping on val loss |
| Gradient clipping | max_norm=1.0 | Prevent training instability |
| Mixed precision | fp16 (torch.amp) | Halves VRAM, faster matmuls |
| Projection dropout | 0.3 | Stronger regularization in projection heads |
| CWA | Disabled by default | Enable with use_cwa: true |
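The learnable, log-parameterized temperature row could be realized as below. This is a sketch modeled on CLIP's logit scale; the real LogitScaleModule interface may differ.

```python
import math
import torch
import torch.nn as nn

class LogitScale(nn.Module):
    """Learnable temperature, stored in log space and clamped to ln(100)."""

    def __init__(self, init_tau: float = 0.07):
        super().__init__()
        # logit scale = 1 / tau, learned as a log-parameter for stability
        self.log_scale = nn.Parameter(torch.tensor(math.log(1.0 / init_tau)))

    def forward(self) -> torch.Tensor:
        # Clamp so the effective scale never exceeds 100 (i.e. tau >= 0.01)
        return self.log_scale.clamp(max=math.log(100.0)).exp()
```

Log-parameterization keeps the scale positive by construction, and the clamp prevents the contrastive logits from blowing up late in training.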
Config variants:
- configs/train/base.yaml — default config (CCF aggregation, batch_size=512)
- configs/train/mean_pool.yaml — replaces CCF with simple mean pooling (ablation)
- configs/train/ddp.yaml — multi-GPU DDP (batch_size=128/GPU, num_workers=4)
Hardware Budget (RTX 5080, 16 GB VRAM)
| Component | VRAM |
|---|---|
| CrossChannelFormer (4M params) | ~16 MB |
| Image ProjectionHead | ~100 MB |
| Text ProjectionHead (0.6M params) | ~4 MB |
| Batch features (512 × 9 × 5 × 1024, fp16) | ~14 GB |
| Similarity matrix + loss computation | ~1 GB |
| Optimizer states (AdamW) | ~200 MB |
| Total estimated | ~15 GB |
Evaluation Metrics
Retrieval Metrics (Primary)
Given an image embedding, retrieve the correct text description (and vice versa).
| Metric | Description |
|---|---|
| mAP@k | Mean average precision at k (k=1, 5, 10) |
| Recall@k | Fraction of queries with correct match in top-k |
| MRR | Mean reciprocal rank of first correct match |
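Recall@k and MRR for image-to-text retrieval can be computed as below (a sketch assuming matched pairs share the same row index, i.e. one unique text per image in the evaluation set):

```python
import torch

def retrieval_metrics(image_embeds, text_embeds, ks=(1, 5, 10)):
    """Image-to-text Recall@k and MRR (sketch).

    Both inputs: (N, D), L2-normalized; pair i <-> i is the ground truth.
    """
    sims = image_embeds @ text_embeds.T            # (N, N) similarity matrix
    ranks = sims.argsort(dim=1, descending=True)   # best match first
    target = torch.arange(sims.size(0)).unsqueeze(1)
    # 0-based rank of the correct text for each image query
    pos = (ranks == target).float().argmax(dim=1)
    metrics = {f"recall@{k}": (pos < k).float().mean().item() for k in ks}
    metrics["mrr"] = (1.0 / (pos + 1)).mean().item()
    return metrics
```

Text-to-image metrics follow by swapping the arguments; in practice many wells share a perturbation, so the real evaluator needs to treat all texts of the same broad_sample as correct.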
Perturbation Matching Metrics
| Metric | Description |
|---|---|
| NSC (Normalized Similarity to Controls) | How well perturbation embeddings separate from DMSO |
| Perturbation retrieval accuracy | Given a well, can we identify the correct perturbation? |
| Cross-modality matching | Given compound X, retrieve CRISPR knockouts of its target gene |
Embedding Quality Metrics
| Metric | Description |
|---|---|
| Alignment | Mean cosine similarity of matched (image, text) pairs |
| Uniformity | Distribution of embeddings on unit hypersphere |
| Batch effect score | Silhouette score of plate clustering (lower = better correction) |
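Alignment and uniformity can be measured roughly as follows. This sketch follows the common Wang & Isola (2020) definitions, which the source does not spell out, so the exact formulas used in the evaluator may differ.

```python
import torch

def alignment(image_embeds, text_embeds):
    """Mean cosine similarity of matched (image, text) pairs; higher = better."""
    return (image_embeds * text_embeds).sum(dim=-1).mean()

def uniformity(embeds, t: float = 2.0):
    """Log of mean Gaussian-kernel pairwise similarity; lower = more uniform."""
    sq_dists = torch.cdist(embeds, embeds).pow(2)
    # Exclude self-pairs on the diagonal
    mask = ~torch.eye(embeds.size(0), dtype=torch.bool)
    return (-t * sq_dists[mask]).exp().mean().log()
```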
Implementation Order
Phase 1: Baseline (MVP)
Build the minimum viable training loop with standard CLIP loss.
- CrossChannelFormer — 1-layer transformer, mean-pool aggregation
- Image encoder wrapper — CrossChannelFormer + ProjectionHead + site mean-pool
- Standard InfoNCE loss — symmetric CLIP loss (no soft labels yet)
- Training script — basic loop with AdamW, cosine LR, early stopping
- Evaluation — Recall@1/5/10, mAP on validation set
Goal: Verify the pipeline trains end-to-end and embeddings are meaningful.
Phase 2: Full MorphoCLIP
Add the MorphoCLIP-specific components.
- CWCL loss — soft positive weighting based on perturbation identity
- CWA batch correction — per-plate mean subtraction during training
- Perturbation-aware evaluation — NSC, cross-modality matching
Goal: Match or exceed CellCLIP performance on perturbation retrieval.
Phase 3: Ablations
- LoRA fine-tuning — rank 8–16 on DINOv3, using cached resized tensors
- CrossChannelFormer depth — 1 vs 2 vs 4 layers
- Site aggregation — mean vs attention pooling
- Text granularity — full vs name_target vs name_only prompts
- Split strategy — perturbation vs plate-level held-out
File Plan
```
src/morphoclip/
  models/
    cross_channel_former.py   # CrossChannelFormer (1-layer transformer)
    image_encoder.py          # Full image encoder wrapper
  training/
    __init__.py
    losses.py                 # InfoNCE + CWCL
    batch_correction.py       # CWA (cross-well alignment)
    trainer.py                # Training loop, checkpointing, logging
    evaluator.py              # Retrieval metrics, perturbation matching
scripts/
  training/
    train.py                  # CLI entry point
    eval.py                   # CLI entry point for evaluation
```

Key Design Decisions
- Pre-computed features, not end-to-end — DINOv3 is frozen, so we extract and cache CLS tokens once (~3 GB/plate). Training iterates over .pt files, not raw images. This enables ~3-5 min/epoch on a single GPU.
- Both text and image projection heads are trainable — Unlike CellCLIP, which only trains the text side, MorphoCLIP trains both projection heads. The CrossChannelFormer is also trainable.
- Cached 768-d BERT features — Raw BERT outputs are cached and projected during training. This means the text ProjectionHead can be re-trained without re-running BERT (~10x speedup).
- Multi-perturbation training — MorphoCLIP trains on compounds AND genetic perturbations (CRISPR + ORF) jointly, unlike CellCLIP, which trains only on compounds. This enables cross-modality retrieval.
- Stratified mixed splits — The pert_type strategy distributes all perturbation types (compound, CRISPR, ORF) across train/val/test by hashing broad_sample IDs. This ensures the model sees every modality during training and that eval loss is comparable to train loss. Wells sharing the same perturbation always land in the same split. Uses only local metadata — no external submodules needed.
- Well-level samples — Each training sample is a well (multiple sites), not a single site. Site embeddings are mean-pooled. This reduces noise and matches the biological unit of interest.
References
- CellCLIP (Ye et al., 2024) — Text-supervised contrastive learning for Cell Painting, compounds only
- CWCL — Continuously Weighted Contrastive Learning for soft-positive handling
- CWA — Cross-Well Alignment for batch effect correction
- DINOv3 — Self-supervised ViT for image feature extraction
- CPJUMP1 — Joint Undertaking for Morphological Profiling pilot dataset (56 plates)