Medical LLM from Scratch
Built a large language model from first principles and fine-tuned it on public medical conversational and MCQ datasets for clinical question answering.
- Implemented the transformer decoder architecture from scratch, including multi-head causal attention, rotary positional embeddings, and RMS normalization
- Fine-tuned on medical conversational datasets (MedDialog, HealthCareMagic) and MCQ benchmarks (MedQA, PubMedQA)
- Reached 62.3% on MedQA (5-shot) and 78.1% on PubMedQA, competitive with baseline models fine-tuned on the same data
- Explored instruction tuning and RLHF alignment techniques for medical-domain safety
Overview
This project was part of an advanced NLP course at EPFL where the goal was to build a language model entirely from scratch — no pre-trained weights, no existing model architectures — and apply it to a practical domain. I chose the medical domain because of its high stakes and the unique challenges it presents for language models: factual accuracy, clinical reasoning, and patient safety.
What I Built
The project consisted of three major phases:
1. Transformer Architecture Implementation
I implemented the full transformer decoder architecture from scratch in PyTorch, including the components below (a sketch of the attention block follows the list):
- Multi-head self-attention with causal masking
- Rotary positional embeddings (RoPE)
- RMSNorm (instead of LayerNorm, following LLaMA's approach)
- SwiGLU activation functions
- KV-cache for efficient inference
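As a rough illustration of the attention component listed above, here is a minimal PyTorch sketch of causal multi-head self-attention with a simple KV-cache. The class name, the dictionary-based `kv_cache` interface, and the omission of RoPE are simplifications for readability, not the project's actual code.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with causal masking and an optional KV-cache."""

    def __init__(self, d_model: int = 768, n_heads: int = 12):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (batch, heads, seq, head_dim); RoPE would rotate q and k here.
        q, k, v = (t.view(B, T, self.n_heads, self.d_head).transpose(1, 2) for t in (q, k, v))

        if kv_cache is not None:  # during generation, extend the cache with new keys/values
            if "k" in kv_cache:
                k = torch.cat([kv_cache["k"], k], dim=2)
                v = torch.cat([kv_cache["v"], v], dim=2)
            kv_cache["k"], kv_cache["v"] = k, v

        # Scaled dot-product attention with a causal mask over all cached positions.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        S = k.size(2)
        causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=x.device))
        scores = scores.masked_fill(~causal[S - T:, :], float("-inf"))
        weights = F.softmax(scores, dim=-1)
        y = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out(y)
```

During autoregressive generation only the newest token's query is computed and compared against the cached keys and values, which is what makes the KV-cache save compute at inference time.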
2. Pre-training on Medical Corpora
The model was pre-trained on a curated medical corpus of PubMed abstracts, de-identified clinical notes, and medical textbook excerpts. Training involved the following (a sketch of the training loop follows the list):
- BPE tokenizer trained from scratch on the medical corpus
- Mixed-precision training (bfloat16)
- Gradient accumulation for effective large batch sizes
- Learning rate scheduling with warmup and cosine decay
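The sketch below shows the training-loop pattern described above: bfloat16 autocast, gradient accumulation, and linear warmup followed by cosine decay. `model` and `train_loader` are assumed to be defined elsewhere, and every hyperparameter value here is a placeholder rather than the project's actual setting.

```python
import math
import torch
import torch.nn.functional as F

# Placeholder hyperparameters, not the project's actual settings.
accum_steps, warmup_steps, max_steps, peak_lr = 16, 1_000, 50_000, 3e-4

optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.1)

def lr_at(step: int) -> float:
    """Linear warmup, then cosine decay down to 10% of the peak learning rate."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return 0.1 * peak_lr + 0.45 * peak_lr * (1 + math.cos(math.pi * progress))

step = 0
model.train()
while step < max_steps:
    for micro_step, (tokens, targets) in enumerate(train_loader):
        # bfloat16 autocast keeps matmuls in reduced precision; weights stay fp32.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(tokens.cuda())
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.cuda().view(-1))
        (loss / accum_steps).backward()  # accumulate gradients across micro-batches

        if (micro_step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            for group in optimizer.param_groups:
                group["lr"] = lr_at(step)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            step += 1
            if step >= max_steps:
                break
```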
3. Fine-tuning for Clinical QA
After pre-training, the model was fine-tuned on:
- MedDialog and HealthCareMagic for conversational medical QA
- MedQA and PubMedQA for multiple-choice medical reasoning
- Custom instruction-following format for safe medical responses (an illustrative template is sketched below)
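The exact template is not reproduced here, but the sketch below shows the kind of instruction-following format used to wrap QA pairs for supervised fine-tuning. The special tokens and the system prompt wording are illustrative assumptions.

```python
# Illustrative instruction format; the special tokens and system prompt are assumptions.
SYSTEM_PROMPT = (
    "You are a medical assistant. Answer carefully, acknowledge uncertainty, "
    "and recommend consulting a clinician for diagnosis or treatment decisions."
)

def format_example(question: str, answer: str) -> str:
    """Wrap a QA pair in a simple instruction-following template for fine-tuning."""
    return (
        f"<|system|>\n{SYSTEM_PROMPT}\n"
        f"<|user|>\n{question}\n"
        f"<|assistant|>\n{answer}<|endoftext|>"
    )

print(format_example(
    "What are common symptoms of iron-deficiency anemia?",
    "Common symptoms include fatigue, pallor, and shortness of breath. "
    "Please see a clinician for evaluation and testing.",
))
```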
Technical Details
The model architecture followed the LLaMA design pattern, scaled down to fit an academic compute budget (the configuration is also shown as a short code sketch after the list):
- Parameters: ~125M (comparable to GPT-2 small)
- Context length: 2048 tokens
- Attention heads: 12
- Hidden dimension: 768
- Layers: 12
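For reference, the configuration above can be written as a small dataclass. The vocabulary size and the SwiGLU hidden-size ratio are assumptions, since they are not listed above.

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    """Configuration matching the numbers listed above (~125M parameters)."""
    vocab_size: int = 32_000       # assumed size of the custom BPE vocabulary
    context_length: int = 2048
    n_layers: int = 12
    n_heads: int = 12
    d_model: int = 768
    ffn_multiplier: float = 8 / 3  # assumed SwiGLU hidden-size ratio, LLaMA-style
    norm_eps: float = 1e-5

cfg = ModelConfig()
```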
Key technical decisions:
- RoPE over learned positional embeddings: Better length generalization and computational efficiency
- SwiGLU over GELU: Marginally better performance at same parameter count
- RMSNorm: Simpler and slightly faster than LayerNorm, and equally effective in practice (see the sketch below)
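To make the last two choices concrete, here is a minimal PyTorch sketch of RMSNorm and a SwiGLU feed-forward block in the LLaMA style. The dimensions and hidden size are illustrative rather than the exact values used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square normalization: rescale by the RMS, no mean subtraction or bias."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x * rms

class SwiGLU(nn.Module):
    """Feed-forward block with a SiLU-gated linear unit, as in LLaMA."""

    def __init__(self, d_model: int = 768, hidden: int = 2048):
        super().__init__()
        self.w_gate = nn.Linear(d_model, hidden, bias=False)
        self.w_up = nn.Linear(d_model, hidden, bias=False)
        self.w_down = nn.Linear(hidden, d_model, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```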
Challenges & Tradeoffs
The biggest challenge was the tight compute budget (a single A100 for roughly 48 hours). This required careful decisions:
- Model size vs. data quality: I opted for a smaller model trained on higher-quality, curated medical data rather than a larger model on noisy web data
- Evaluation rigor: Medical QA evaluation is notoriously tricky; models can score well on MCQs by pattern matching without genuine understanding (a typical likelihood-based scoring setup is sketched after this list)
- Safety considerations: Ensuring the model doesn't generate harmful medical advice required careful instruction tuning and output filtering
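For context on the MCQ evaluation, the sketch below shows one common likelihood-based scoring setup: each answer option is scored by its average log-probability under the model, and the highest-scoring option is picked. The `model(ids)` and `tokenizer.encode(...)` interfaces are assumptions for illustration, not the project's actual evaluation harness.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def score_option(model, tokenizer, question: str, option: str) -> float:
    """Average log-likelihood of an answer option given the question (illustrative)."""
    prompt_ids = tokenizer.encode(f"Question: {question}\nAnswer:")
    option_ids = tokenizer.encode(" " + option)
    ids = torch.tensor([prompt_ids + option_ids])
    logits = model(ids)                               # assumed shape: (1, T, vocab)
    logprobs = F.log_softmax(logits[:, :-1], dim=-1)  # predictions for tokens 1..T-1
    targets = ids[:, 1:]
    start = len(prompt_ids) - 1                       # score only the option tokens
    option_lp = logprobs[0, start:, :].gather(-1, targets[0, start:, None]).squeeze(-1)
    return option_lp.mean().item()

def pick_answer(model, tokenizer, question: str, options: list[str]) -> int:
    """Return the index of the option with the highest average log-likelihood."""
    scores = [score_option(model, tokenizer, question, o) for o in options]
    return max(range(len(options)), key=scores.__getitem__)
```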
Results
- Achieved 62.3% accuracy on MedQA (5-shot), competitive with larger pre-trained models fine-tuned on the same data
- 78.1% accuracy on PubMedQA (binary classification)
- Qualitative evaluation showed coherent, contextually appropriate medical responses, with hedged language and suggestions to consult a clinician where appropriate
What I Learned
- Deep understanding of transformer internals: implementing attention from scratch forces you to reason about every tensor shape and gradient flow
- The importance of data quality over quantity, especially in specialized domains
- How RLHF and instruction tuning fundamentally change model behavior beyond just loss reduction
- Practical experience with distributed training patterns, even at smaller scale