
MorphoCLIP Training Pipeline

💡

In Plain English: MorphoCLIP learns to match microscopy images with text descriptions of biological treatments. It takes cell images processed through a vision AI (DINOv3) and text descriptions processed through a language AI (ModernBERT), then trains small "adapter" networks to place matching image-text pairs near each other in a shared mathematical space. See the Glossary for term definitions.

Architecture

MorphoCLIP Architecture

Status

| Component | Status |
| --- | --- |
| DINOv3 feature extraction | Done |
| Text encoder + projection head | Done |
| Dataset + collate + splits | Done |
| Text embedding caching (768-d) | Done |
| CrossChannelFormer | Done |
| Image projection head | Done |
| InfoNCE loss | Done |
| CWCL loss (soft labels) | Done |
| CWA batch correction | Done |
| Training loop + checkpointing | Done |
| Retrieval metrics (R@k, MRR) | Done |
| Perturbation matching metrics | To build |
| LoRA fine-tuning | To build |

Architecture Overview

IMAGE BRANCH                                TEXT BRANCH
────────────                                ───────────
5 fluorescence channels/site                PerturbationInfo
            │                                      │
DINOv3 ViT-L/16 (frozen)                    PromptBuilder (verbose templates)
            │                                      │
CLS tokens (5, 1024)/site                   BioClinical ModernBERT (frozen)
            │                                      │
┌──────────────────────┐                    CLS token (768-d) ← cached
│  CrossChannelFormer  │                           │
│  (1-layer xformer)   │                    ProjectionHead (768→512)
└──────────┬───────────┘                           │
           │ (1, 1024)/site                 512-d L2-normalized
Image ProjectionHead (1024→512)                    │
           │                                       │
512-d L2-normalized                                │
           │                                       │
           └───────────────┬───────────────────────┘
                           │
         CWCL Contrastive Loss + CWA batch correction

Components to Build

1. CrossChannelFormer (src/morphoclip/models/cross_channel_former.py)

Aggregates 5 per-channel CLS tokens into a single image representation per site.

Why needed: DINOv3 processes each of the 5 fluorescence channels independently (replicated to pseudo-RGB). The CrossChannelFormer learns cross-channel interactions — e.g., how mitochondrial morphology relates to actin cytoskeleton structure — producing one unified 1024-d token per site.

Input: (batch, 5, 1024) — 5 channel CLS tokens
Output: (batch, 1024) — single aggregated representation

Design:

  • 1-layer transformer encoder (nn.TransformerEncoder, default; configurable via ccf_layers)
  • 5 learnable channel-type embeddings (one per: Mitochondria, Actin, Golgi, ER, DNA) added to input tokens
  • Aggregation: mean-pool over the 5 output tokens (or use a learnable CLS query token)
  • 4 attention heads (configurable via ccf_heads)
  • ~4M parameters

Alternatives considered:

  • Simple mean pooling (no learnable cross-channel interaction — baseline ablation)
  • Single attention layer (fewer params but may underfit)
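A minimal PyTorch sketch of this design (the feed-forward width, pre-norm, and pooling details are assumptions, not the repo's exact hyperparameters):

```python
import torch
import torch.nn as nn

class CrossChannelFormer(nn.Module):
    """Aggregate 5 per-channel CLS tokens into one token per site."""

    def __init__(self, dim: int = 1024, depth: int = 1, heads: int = 4):
        super().__init__()
        # One learnable embedding per fluorescence channel
        # (Mitochondria, Actin, Golgi, ER, DNA).
        self.channel_embed = nn.Parameter(torch.zeros(1, 5, dim))
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=2 * dim,
            batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 5, dim) channel CLS tokens
        x = self.encoder(x + self.channel_embed)
        return x.mean(dim=1)  # mean-pool over channels -> (batch, dim)
```

Swapping the mean-pool for a learnable CLS query token would be a drop-in change to the `forward` method.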

2. Image Projection Head (src/morphoclip/models/image_projection_head.py)

Maps aggregated image representation to shared contrastive space.

Input: (batch, 1024) — CrossChannelFormer output
Output: (batch, 512) — L2-normalized embedding

Design: Reuse existing ProjectionHead class with input_dim=1024:

image_proj = ProjectionHead(input_dim=1024, hidden_dim=512, output_dim=512)
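For reference, the existing ProjectionHead presumably looks roughly like this; the exact internals (activation choice, dropout placement) are assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """Two-layer MLP projecting features into the shared 512-d space,
    with an L2-normalized output. A plausible sketch of the repo's class."""

    def __init__(self, input_dim: int, hidden_dim: int = 512,
                 output_dim: int = 512, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.normalize(self.net(x), dim=-1)  # unit-norm embedding
```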

3. Site Aggregation

Each well has multiple imaging sites (up to 9). Sites must be aggregated into a single well-level embedding before loss computation.

Options (in order of simplicity):

  1. Mean pooling — average site embeddings (use site_mask from collate_fn to exclude padding)
  2. Attention pooling — learnable weighted average over sites
  3. Max pooling — element-wise max

Recommendation: Start with masked mean pooling. It's simple, proven, and the site-level variation within a well is small.
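Masked mean pooling is a few lines; a sketch consistent with the site_mask convention from collate_fn:

```python
import torch

def masked_mean_pool(site_embeds: torch.Tensor,
                     site_mask: torch.Tensor) -> torch.Tensor:
    """Average site embeddings per well, ignoring padded sites.

    site_embeds: (B, max_sites, D) padded site embeddings
    site_mask:   (B, max_sites) bool, True for real sites
    """
    mask = site_mask.unsqueeze(-1).float()      # (B, S, 1)
    summed = (site_embeds * mask).sum(dim=1)    # (B, D)
    counts = mask.sum(dim=1).clamp(min=1.0)     # avoid div-by-zero
    return summed / counts
```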

4. CWCL Loss (src/morphoclip/training/losses.py)

Continuously Weighted Contrastive Loss — handles soft positives (multiple wells with the same perturbation).

Standard CLIP uses hard positive pairs (1 image ↔ 1 text). In Cell Painting, many wells share the same perturbation, creating multiple valid positives per text description. CWCL assigns continuous similarity weights instead of binary labels.

Input: image_embeds (B, 512), text_embeds (B, 512)
Output: scalar loss

Algorithm:

  1. Compute cosine similarity matrix: S = image_embeds @ text_embeds.T → (B, B)

  2. Build soft label matrix W where W[i,j] reflects how similar perturbation_i is to perturbation_j:

    • Same broad_sample → weight 1.0
    • Same target gene (different compound) → weight 0.5–0.8
    • Different perturbation → weight 0.0
  3. Temperature-scaled cross-entropy with soft labels:

    loss_i2t = -sum(W[i,:] * log_softmax(S[i,:] / tau))
    loss_t2i = -sum(W[:,j] * log_softmax(S[:,j] / tau))
    loss = (loss_i2t + loss_t2i) / 2
  4. Learnable temperature tau (initialized to 0.07, log-parameterized)

Fallback: Start with standard symmetric InfoNCE (CLIP loss) as a baseline, then add soft weighting.
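A sketch of the loss, taking the soft-label matrix W as an argument. Row/column normalization of W is an assumption here; with W equal to the identity this reduces to standard symmetric InfoNCE:

```python
import torch
import torch.nn.functional as F

def cwcl_loss(image_embeds: torch.Tensor, text_embeds: torch.Tensor,
              weights: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    """Continuously weighted contrastive loss (sketch).

    image_embeds, text_embeds: (B, D) L2-normalized embeddings
    weights: (B, B) soft label matrix, weights[i, j] in [0, 1]
    """
    sim = image_embeds @ text_embeds.T / tau                          # (B, B)
    # Normalize soft labels so each query's weights sum to 1.
    w_i2t = weights / weights.sum(dim=1, keepdim=True).clamp(min=1e-8)
    w_t2i = weights / weights.sum(dim=0, keepdim=True).clamp(min=1e-8)
    loss_i2t = -(w_i2t * F.log_softmax(sim, dim=1)).sum(dim=1).mean()
    loss_t2i = -(w_t2i * F.log_softmax(sim, dim=0)).sum(dim=0).mean()
    return 0.5 * (loss_i2t + loss_t2i)
```

A fixed `tau` is used here for simplicity; the training loop would pass the learnable temperature instead.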

5. CWA Batch Correction (src/morphoclip/training/batch_correction.py)

Cross-Well Alignment — removes plate-specific technical artifacts from image embeddings during training.

Cell Painting images have significant batch effects (plate-to-plate variation from staining, imaging conditions). Without correction, the model learns to cluster by plate rather than by perturbation.

Algorithm:

  1. Within each training batch, group image embeddings by plate
  2. Compute per-plate mean embedding: mu_p = mean(embeds[plate == p])
  3. Subtract plate mean: embeds_corrected = embeds - mu_p
  4. Re-normalize to unit sphere: embeds_corrected = L2_norm(embeds_corrected)

When applied: After the image projection head, before loss computation. Text embeddings are not corrected (no batch effect in text).

Requires: Batch sampler that ensures each batch contains wells from multiple plates (see Data Loading below).
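The four-step algorithm above, as a sketch (a production version might vectorize the per-plate group-by):

```python
import torch
import torch.nn.functional as F

def cross_well_align(embeds: torch.Tensor, plates) -> torch.Tensor:
    """Subtract the per-plate mean embedding within a batch (CWA sketch).

    embeds: (B, D) image embeddings
    plates: length-B sequence of plate identifiers
    """
    corrected = embeds.clone()
    for p in set(plates):
        idx = torch.tensor([i for i, q in enumerate(plates) if q == p])
        mu = embeds[idx].mean(dim=0, keepdim=True)   # per-plate mean
        corrected[idx] = embeds[idx] - mu            # remove plate offset
    return F.normalize(corrected, dim=-1)            # back onto unit sphere
```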

6. Full Image Encoder (src/morphoclip/models/image_encoder.py)

Wraps the full image-side pipeline:

class MorphoCLIPImageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.cross_channel_former = CrossChannelFormer(dim=1024, depth=1, heads=4)
        self.projection_head = ProjectionHead(input_dim=1024, output_dim=512)

    def forward(self, features, site_mask):
        """
        Args:
            features: (B, max_sites, 5, 1024) — padded DINOv3 CLS tokens
            site_mask: (B, max_sites) — True for real sites
        Returns:
            (B, 512) — L2-normalized well embeddings
        """
        B, S, C, D = features.shape
        # Merge batch and site dims
        x = features.view(B * S, C, D)                # (B*S, 5, 1024)
        x = self.cross_channel_former(x)              # (B*S, 1024)
        x = x.view(B, S, -1)                          # (B, S, 1024)
        # Masked mean pooling over sites
        mask = site_mask.unsqueeze(-1).float()        # (B, S, 1)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)   # (B, 1024)
        x = self.projection_head(x)                   # (B, 512), L2-normalized
        return x

Training Loop Design

Data Loading

# 1. Build dataset
metadata = MetadataIndex.from_config(config_path)
dataset = MorphoCLIPDataset(
    feature_dir=Path("data/features"),
    metadata=metadata,
    plates=train_plates,
    mode="features",
    text_level="full",
    exclude_controls=True,
)

# 2. Split (stratified by broad_sample across all pert types)
train_set, val_set, test_set = create_splits(
    dataset,
    strategy="pert_type",
    val_fraction=0.1,  # 10% val, 10% test, 80% train
)

# 3. DataLoader with custom collate
train_loader = DataLoader(
    train_set, batch_size=512, shuffle=True,
    collate_fn=collate_fn, num_workers=0,
)

Split strategy: The pert_type strategy uses only local metadata (data/metadata/). Each unique broad_sample is deterministically hashed to train/val/test, so all wells sharing the same perturbation land in the same split. Compounds, CRISPR, and ORF are all represented in every split, ensuring the model trains and evaluates on all perturbation modalities. Benchmark-aligned strategies (cpjump1_official_*, cellclip_cpjump_style) are available in benchmark.splits for comparison with published baselines.

Note: CWA is disabled by default (use_cwa: false in base config). When enabled, each batch needs wells from multiple plates. With 51 plates and batch_size=512, random shuffling provides sufficient plate diversity.

Data loading: With preload: true (default), all features are loaded into RAM at startup with num_workers: 0. Set preload: false and increase num_workers if RAM is constrained.

Text Embedding Strategy

Text embeddings are pre-computed and cached (768-d raw BERT features). During training:

  1. Load cached text embeddings at startup using load_cached_text_features()
  2. Look up cached 768-d embedding by broad_sample ID
  3. Pass through the (trainable) text ProjectionHead to get 512-d
  4. This avoids running BERT forward pass every epoch (~10x faster)
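A minimal illustration of the lookup-then-project flow (the cache contents, the `lookup_cached_embeddings` helper, and the `nn.Linear` stand-in for the repo's ProjectionHead are hypothetical):

```python
import torch
import torch.nn as nn

# Hypothetical cache: broad_sample id -> 768-d raw ModernBERT CLS embedding.
# In the repo this would be populated by load_cached_text_features().
text_cache = {
    "SAMPLE-A": torch.randn(768),
    "SAMPLE-B": torch.randn(768),
}

def lookup_cached_embeddings(broad_samples):
    """Stack cached 768-d text features for a batch of perturbation ids."""
    return torch.stack([text_cache[b] for b in broad_samples])  # (B, 768)

# Only this small projection needs gradients; the frozen BERT encoder
# never runs during training.
text_proj = nn.Linear(768, 512)  # stand-in for the trainable ProjectionHead
batch_text = text_proj(lookup_cached_embeddings(["SAMPLE-A", "SAMPLE-B"]))
```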

Training Step Pseudocode

for batch in train_loader:
    # Image branch
    image_embeds = image_encoder(batch["features"], batch["site_mask"])  # (B, 512)

    # Text branch (from cache)
    text_raw = lookup_cached_embeddings(batch["pert_info"])  # (B, 768)
    text_embeds = text_projection_head(text_raw)             # (B, 512)

    # Batch correction (CWA)
    image_embeds = cross_well_align(image_embeds, batch["plates"])

    # Loss
    loss = cwcl_loss(image_embeds, text_embeds, batch["pert_info"])

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Hyperparameters

| Parameter | Value | Notes |
| --- | --- | --- |
| Batch size | 512 | More negatives for contrastive learning |
| Learning rate | 1e-4 | AdamW, weight_decay=0.2 |
| LR schedule | Cosine annealing with warmup | 100 warmup steps |
| Temperature (tau) | Learnable (LogitScaleModule) | Log-parameterized, clamped to ln(100) |
| Epochs | 100 | Early stopping on val loss |
| Gradient clipping | max_norm=1.0 | Prevents training instability |
| Mixed precision | fp16 (torch.amp) | Halves VRAM, faster matmuls |
| Projection dropout | 0.3 | Stronger regularization in projection heads |
| CWA | Disabled by default | Enable with use_cwa: true |
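The learnable temperature can be sketched as a log-parameterized scalar, clamped as in CLIP (the internals of the repo's LogitScaleModule are assumptions here):

```python
import math
import torch
import torch.nn as nn

class LogitScaleModule(nn.Module):
    """Learnable temperature, log-parameterized (sketch).

    Stores log(1/tau) so the similarity scale stays positive; clamped at
    ln(100), i.e. tau never drops below 0.01.
    """

    def __init__(self, init_tau: float = 0.07):
        super().__init__()
        self.log_scale = nn.Parameter(torch.tensor(math.log(1.0 / init_tau)))

    def forward(self) -> torch.Tensor:
        return self.log_scale.clamp(max=math.log(100.0)).exp()
```

Usage would look like `sim = logit_scale() * image_embeds @ text_embeds.T`, letting the optimizer tune the temperature alongside the projection heads.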

Config variants:

  • configs/train/base.yaml — default config (CCF aggregation, batch_size=512)
  • configs/train/mean_pool.yaml — replaces CCF with simple mean pooling (ablation)
  • configs/train/ddp.yaml — multi-GPU DDP (batch_size=128/GPU, num_workers=4)

Hardware Budget (RTX 5080, 16 GB VRAM)

| Component | VRAM |
| --- | --- |
| CrossChannelFormer (4M params) | ~16 MB |
| Image ProjectionHead | ~100 MB |
| Text ProjectionHead (0.6M params) | ~4 MB |
| Batch features (512 × 9 × 5 × 1024, fp16) | ~14 GB |
| Similarity matrix + loss computation | ~1 GB |
| Optimizer states (AdamW) | ~200 MB |
| Total estimated | ~15 GB |

Evaluation Metrics

Retrieval Metrics (Primary)

Given an image embedding, retrieve the correct text description (and vice versa).

| Metric | Description |
| --- | --- |
| mAP@k | Mean average precision at k (k=1, 5, 10) |
| Recall@k | Fraction of queries with the correct match in top-k |
| MRR | Mean reciprocal rank of the first correct match |
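Under the convention that row i's correct text is column i of the similarity matrix, Recall@k and MRR can be computed as follows (a sketch; the repo's evaluator may differ):

```python
import numpy as np

def recall_at_k(sim: np.ndarray, k: int) -> float:
    """Fraction of queries whose true match (the diagonal) is in the top-k.

    sim: (N, N) image-to-text similarity; row i's correct text is column i.
    """
    topk = np.argsort(-sim, axis=1)[:, :k]
    return float(np.mean([i in topk[i] for i in range(sim.shape[0])]))

def mrr(sim: np.ndarray) -> float:
    """Mean reciprocal rank of the correct match."""
    order = np.argsort(-sim, axis=1)
    ranks = [int(np.where(order[i] == i)[0][0]) + 1
             for i in range(sim.shape[0])]
    return float(np.mean([1.0 / r for r in ranks]))
```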

Perturbation Matching Metrics

| Metric | Description |
| --- | --- |
| NSC (Normalized Similarity to Controls) | How well perturbation embeddings separate from DMSO |
| Perturbation retrieval accuracy | Given a well, can we identify the correct perturbation? |
| Cross-modality matching | Given compound X, retrieve CRISPR knockouts of its target gene |

Embedding Quality Metrics

| Metric | Description |
| --- | --- |
| Alignment | Mean cosine similarity of matched (image, text) pairs |
| Uniformity | Distribution of embeddings on the unit hypersphere |
| Batch effect score | Silhouette score of plate clustering (lower = better correction) |
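Alignment and uniformity follow the Wang & Isola formulation; a sketch assuming unit-normalized embeddings:

```python
import torch

def alignment(img: torch.Tensor, txt: torch.Tensor) -> float:
    """Mean cosine similarity of matched pairs (inputs unit-normalized)."""
    return (img * txt).sum(dim=-1).mean().item()

def uniformity(z: torch.Tensor, t: float = 2.0) -> float:
    """Log mean Gaussian potential over all embedding pairs.

    More negative = embeddings spread more evenly on the hypersphere.
    """
    sq_dists = torch.pdist(z, p=2).pow(2)
    return torch.log(torch.exp(-t * sq_dists).mean()).item()
```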

Implementation Order

Phase 1: Baseline (MVP)

Build the minimum viable training loop with standard CLIP loss.

  1. CrossChannelFormer — 1-layer transformer, mean-pool aggregation
  2. Image encoder wrapper — CrossChannelFormer + ProjectionHead + site mean-pool
  3. Standard InfoNCE loss — symmetric CLIP loss (no soft labels yet)
  4. Training script — basic loop with AdamW, cosine LR, early stopping
  5. Evaluation — Recall@1/5/10, mAP on validation set
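The cosine schedule with warmup from item 4 can be expressed as an LR multiplier for LambdaLR (the 10,000 total steps here are an illustrative assumption):

```python
import math
import torch

def warmup_cosine_lambda(step: int, warmup: int = 100,
                         total: int = 10_000) -> float:
    """LR multiplier: linear warmup for `warmup` steps, then cosine decay."""
    if step < warmup:
        return (step + 1) / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

# Attach to AdamW with the table's settings (dummy parameter for illustration).
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))],
                        lr=1e-4, weight_decay=0.2)
sched = torch.optim.lr_scheduler.LambdaLR(opt, warmup_cosine_lambda)
```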

Goal: Verify the pipeline trains end-to-end and embeddings are meaningful.

Phase 2: Full MorphoCLIP

Add the MorphoCLIP-specific components.

  1. CWCL loss — soft positive weighting based on perturbation identity
  2. CWA batch correction — per-plate mean subtraction during training
  3. Perturbation-aware evaluation — NSC, cross-modality matching

Goal: Match or exceed CellCLIP performance on perturbation retrieval.

Phase 3: Ablations

  1. LoRA fine-tuning — rank 8–16 on DINOv3, using cached resized tensors
  2. CrossChannelFormer depth — 1 vs 2 vs 4 layers
  3. Site aggregation — mean vs attention pooling
  4. Text granularity — full vs name_target vs name_only prompts
  5. Split strategy — perturbation vs plate-level held-out

File Plan

src/morphoclip/
  models/
    cross_channel_former.py   # CrossChannelFormer (1-layer default)
    image_encoder.py          # Full image encoder wrapper
  training/
    __init__.py
    losses.py                 # InfoNCE + CWCL
    batch_correction.py       # CWA (cross-well alignment)
    trainer.py                # Training loop, checkpointing, logging
    evaluator.py              # Retrieval metrics, perturbation matching
scripts/
  training/
    train.py                  # CLI entry point
    eval.py                   # CLI entry point for evaluation

Key Design Decisions

  1. Pre-computed features, not end-to-end — DINOv3 is frozen, so we extract and cache CLS tokens once (~3 GB/plate). Training iterates over .pt files, not raw images. This enables ~3–5 min/epoch on a single GPU.

  2. Both text and image projection heads are trainable — Unlike CellCLIP, which trains only the text side, MorphoCLIP trains both projection heads. The CrossChannelFormer is also trainable.

  3. Cached 768-d BERT features — Raw BERT outputs are cached and projected during training. This means the text ProjectionHead can be re-trained without re-running BERT (~10x speedup).

  4. Multi-perturbation training — MorphoCLIP trains on compounds AND genetic perturbations (CRISPR + ORF) jointly, unlike CellCLIP, which trains only on compounds. This enables cross-modality retrieval.

  5. Stratified mixed splits — The pert_type strategy distributes all perturbation types (compound, CRISPR, ORF) across train/val/test by hashing broad_sample IDs. This ensures the model sees every modality during training and eval loss is comparable to train loss. Wells sharing the same perturbation always land in the same split. Uses only local metadata — no external submodules needed.

  6. Well-level samples — Each training sample is a well (multiple sites), not a single site. Site embeddings are mean-pooled. This reduces noise and matches the biological unit of interest.


References

  • CellCLIP (Ye et al., 2024) — Text-supervised contrastive learning for Cell Painting, compounds only
  • CWCL — Continuously Weighted Contrastive Learning for soft-positive handling
  • CWA — Cross-Well Alignment for batch effect correction
  • DINOv3 — Self-supervised ViT for image feature extraction
  • CPJUMP1 — Joint Undertaking in Morphological Profiling pilot dataset (56 plates)