
Medical LLM from Scratch
Built a large language model from first principles and fine-tuned it on public medical conversational and MCQ datasets for clinical question answering.
- Implemented a transformer decoder from scratch, including multi-head causal self-attention, rotary positional embeddings (RoPE), and RMSNorm
- Fine-tuned on medical conversational datasets (MedDialog, HealthCareMagic) and MCQ benchmarks (MedQA, PubMedQA)
- Reached 62.3% accuracy on MedQA (5-shot) and 78.1% on PubMedQA
- Explored instruction tuning and RLHF alignment techniques for medical domain safety
Overview
This was part of an advanced NLP course at EPFL. The brief: build a language model from scratch (no pre-trained weights, no off-the-shelf model code) and apply it to something real. I picked the medical domain because it's one of the hardest applications for LLMs: you need factual accuracy and clinical reasoning, and you really can't afford hallucinations.
The code is split across two repos: the base LLM implementation and the medical fine-tuning pipeline.
What I Built
Three main phases:
1. Transformer Architecture Implementation
I implemented the full transformer decoder architecture from scratch in PyTorch, including:
- Multi-head self-attention with causal masking
- Rotary positional embeddings (RoPE)
- RMSNorm (instead of LayerNorm, following LLaMA's approach)
- SwiGLU activation functions
- KV-cache for efficient inference
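To make the causal masking concrete, here is a minimal single-head sketch in plain Python. It is an illustration only, not the project's PyTorch code: the function name `causal_attention` and the list-of-lists representation are assumptions chosen for readability, and it omits the learned projections, multiple heads, and the KV-cache.

```python
import math

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v are lists of T vectors (lists of floats). Returns (outputs, weights),
    where weights[i][j] is how much position i attends to position j.
    """
    T, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    outputs, weights = [], []
    for i in range(T):
        # Causal mask: position i only scores positions 0..i, never the future.
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        # Future positions get an explicit weight of zero.
        row = [e / z for e in exps] + [0.0] * (T - i - 1)
        weights.append(row)
        outputs.append([sum(row[j] * v[j][c] for j in range(T)) for c in range(len(v[0]))])
    return outputs, weights

# Tiny usage example with three 2-dimensional token vectors.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, w = causal_attention(x, x, x)
```

Each row of `w` sums to 1 and is zero above the diagonal, which is the property that lets a decoder train on all positions of a sequence in parallel without leaking future tokens.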
2. Pre-training on Medical Corpora
The model was pre-trained on a curated medical text corpus including PubMed abstracts, clinical notes (de-identified), and medical textbook excerpts. Training involved:
- BPE tokenizer trained from scratch on the medical corpus
- Mixed-precision training (bfloat16)
- Gradient accumulation for effective large batch sizes
- Learning rate scheduling with warmup and cosine decay
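The warmup-plus-cosine schedule and the effect of gradient accumulation can be sketched in a few lines. All numeric values here (`max_lr`, `min_lr`, `warmup_steps`, `total_steps`, the micro-batch size, and the accumulation steps) are placeholder assumptions, not the settings used in this project; only the 2048-token sequence length comes from the write-up.

```python
import math

def lr_at(step, max_lr=3e-4, min_lr=3e-5, warmup_steps=1000, total_steps=20000):
    """Linear warmup to max_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        # Linear ramp; step + 1 avoids a learning rate of exactly zero at step 0.
        return max_lr * (step + 1) / warmup_steps
    if step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 to 0
    return min_lr + cosine * (max_lr - min_lr)

def tokens_per_update(micro_batch, seq_len, accum_steps):
    """Tokens contributing to one optimizer step under gradient accumulation."""
    return micro_batch * seq_len * accum_steps

# Hypothetical setup: 8 sequences per forward pass, 16 accumulation steps.
effective_tokens = tokens_per_update(micro_batch=8, seq_len=2048, accum_steps=16)
```

In a training loop, `lr_at(step)` would be written into the optimizer's parameter groups once per optimizer step, that is, once every `accum_steps` forward/backward passes.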
3. Fine-tuning for Clinical QA
After pre-training, the model was fine-tuned on:
- MedDialog and HealthCareMagic for conversational medical QA
- MedQA and PubMedQA for multiple-choice medical reasoning
- Custom instruction-following format for safe medical responses
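As an illustration of what an instruction-style format for the MCQ data might look like, the sketch below renders a question into a prompt. The template (the `###` section markers and lettered options) and the function name `format_mcq` are hypothetical; the actual format used in the project is not specified here.

```python
def format_mcq(question, options, answer=None):
    """Render a multiple-choice question as a prompt string.

    answer is the index of the correct option. When it is given, the answer
    letter is appended (a training example); when it is None, the prompt ends
    right after the answer header so the model has to complete it.
    """
    letters = "ABCDEFGH"
    lines = ["### Question", question.strip(), "", "### Options"]
    for letter, text in zip(letters, options):
        lines.append(f"{letter}. {text.strip()}")
    lines += ["", "### Answer"]
    prompt = "\n".join(lines) + "\n"
    if answer is not None:
        prompt += letters[answer]
    return prompt

# Made-up example item, for illustration only.
example = format_mcq(
    "Which vitamin deficiency causes scurvy?",
    ["Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D"],
    answer=2,
)
```

Keeping training and evaluation prompts identical up to the answer letter means MCQ accuracy can be measured by comparing the model's next-token scores for the option letters.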
Technical Details
The model architecture followed the LLaMA design pattern but at a smaller scale suitable for academic compute budgets:
- Parameters: ~125M (comparable to GPT-2 small)
- Context length: 2048 tokens
- Attention heads: 12
- Hidden dimension: 768
- Layers: 12
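A quick back-of-the-envelope check shows how these settings land near 125M parameters. The vocabulary size (50,257, borrowed from GPT-2), the feed-forward width (2048, roughly 8/3 of the hidden dimension as in LLaMA), and tied input/output embeddings are assumptions made for this estimate; the write-up does not state them.

```python
def count_params(d_model=768, n_layers=12, vocab_size=50257, d_ff=2048,
                 tied_embeddings=True):
    """Approximate parameter count for a LLaMA-style decoder (no biases)."""
    attn = 4 * d_model * d_model   # query, key, value, and output projections
    mlp = 3 * d_model * d_ff       # SwiGLU uses gate, up, and down projections
    norms = 2 * d_model            # two RMSNorm gain vectors per block
    per_layer = attn + mlp + norms
    embed = vocab_size * d_model   # RoPE itself adds no parameters
    head = 0 if tied_embeddings else vocab_size * d_model
    return n_layers * per_layer + embed + head + d_model  # + final RMSNorm gain

head_dim = 768 // 12  # 64 dimensions per attention head
print(f"{count_params() / 1e6:.1f}M")  # 123.6M under these assumptions
```

With untied embeddings the same configuration grows by another `vocab_size * d_model` parameters, so the tokenizer's vocabulary size is a significant part of the budget at this scale.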
Key technical decisions:
- RoPE over learned positional embeddings: relative offsets are encoded directly in the query/key rotation, which adds no parameters and tends to generalize better to longer sequences
- SwiGLU over GELU: Marginally better performance at the same parameter count
- RMSNorm: Simpler and slightly faster than LayerNorm, equally effective
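Two of these choices are small enough to sketch in plain Python. This is a simplified illustration rather than the project's code: `rope` rotates adjacent pairs of dimensions as in the original RoFormer formulation (some implementations pair dimension i with i + d/2 instead), and the helper names and `base=10000.0` default are assumptions.

```python
import math

def rope(x, pos, base=10000.0):
    """Apply rotary position embedding to one vector at position pos."""
    d = len(x)
    out = []
    for i in range(0, d, 2):
        theta = base ** (-i / d)  # rotation frequency for this pair of dimensions
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        # Standard 2D rotation of the pair (x[i], x[i + 1]).
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out

def rms_norm(x, gain, eps=1e-6):
    """Scale x by its root mean square; unlike LayerNorm, no mean subtraction or bias."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gain, x)]

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

# The attention score depends only on the offset between positions:
q = [0.3, -1.2, 0.7, 0.5]
k = [1.0, 0.4, -0.6, 0.2]
score_near = dot(rope(q, 5), rope(k, 2))    # offset 3
score_far = dot(rope(q, 13), rope(k, 10))   # also offset 3
```

`score_near` and `score_far` agree up to floating-point error, which is the relative-position property behind the RoPE choice. RMSNorm's saving over LayerNorm comes from skipping the mean computation and the bias term.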
Challenges & Tradeoffs
The biggest challenge was the tight compute budget (a single A100 for ~48h). The main tradeoffs and difficulties were:
- Model size vs. data quality: I opted for a smaller model trained on higher-quality, curated medical data rather than a larger model on noisy web data
- Evaluation rigor: Medical QA evaluation is notoriously tricky, because models can score well on MCQs by pattern matching without genuine understanding
- Safety considerations: Ensuring the model doesn't generate harmful medical advice required careful instruction tuning and output filtering
Results
- Achieved 62.3% accuracy on MedQA (5-shot), competitive with larger pre-trained models fine-tuned on the same data
- 78.1% accuracy on PubMedQA (binary classification)
- Qualitative evaluation showed coherent, contextually appropriate medical responses with appropriate hedging and referral suggestions
What I Learned
- Implementing attention from scratch gives you an understanding of transformers that you just don't get from using a library
- Data quality beats quantity, especially in specialized domains where bad examples actively teach the model wrong things
- RLHF and instruction tuning don't just improve loss numbers; they genuinely change how the model behaves in ways that are hard to predict
- Got hands-on with distributed training patterns, even if the scale was modest