EPFL · School coursework

Medical LLM from Scratch

Built a large language model from first principles and fine-tuned it on public medical conversational and MCQ datasets for clinical question answering.

ml · llm · nlp
  • Implemented the transformer decoder architecture from scratch, including multi-head attention, rotary positional embeddings, and RMSNorm
  • Fine-tuned on medical conversational datasets (MedDialog, HealthCareMagic) and MCQ benchmarks (MedQA, PubMedQA)
  • Reached 62.3% accuracy on MedQA (5-shot) and 78.1% on PubMedQA, competitive with larger pre-trained baselines fine-tuned on the same data
  • Explored instruction tuning and RLHF alignment techniques for medical domain safety
Stack
Python · PyTorch · Hugging Face · Weights & Biases
Role: Individual project
Team: 1 person

Overview

This project was part of an advanced NLP course at EPFL where the goal was to build a language model entirely from scratch — no pre-trained weights, no off-the-shelf model implementations — and apply it to a practical domain. I chose the medical domain because of its high stakes and the unique challenges it presents for language models: factual accuracy, clinical reasoning, and patient safety.

What I Built

The project consisted of three major phases:

1. Transformer Architecture Implementation

I implemented the full transformer decoder architecture from scratch in PyTorch, including the following (a short sketch of RMSNorm and RoPE appears after the list):

  • Multi-head self-attention with causal masking
  • Rotary positional embeddings (RoPE)
  • RMSNorm (instead of LayerNorm, following LLaMA's approach)
  • SwiGLU activation functions
  • KV-cache for efficient inference
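
Since RMSNorm and RoPE are the two pieces most often glossed over, here is a minimal PyTorch sketch of one way to write them, assuming inputs shaped (batch, heads, seq, head_dim) for the rotary part; the names and exact formulation are illustrative, not the project's actual code.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square normalization: rescale by the RMS of the features with a
    learned gain; no mean subtraction and no bias, unlike LayerNorm."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight


def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotary positional embeddings: rotate each channel pair of a
    (batch, heads, seq, head_dim) tensor by an angle that grows with position."""
    _, _, seq_len, head_dim = x.shape
    half = head_dim // 2
    freqs = base ** (-torch.arange(half, device=x.device, dtype=torch.float32) / half)
    angles = torch.arange(seq_len, device=x.device, dtype=torch.float32)[:, None] * freqs
    cos, sin = angles.cos(), angles.sin()  # (seq, half), broadcast over batch and heads
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).to(x.dtype)
```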

2. Pre-training on Medical Corpora

The model was pre-trained on a curated medical text corpus of PubMed abstracts, de-identified clinical notes, and medical textbook excerpts. Training involved (see the loop sketch after this list):

  • BPE tokenizer trained from scratch on the medical corpus
  • Mixed-precision training (bfloat16)
  • Gradient accumulation for large effective batch sizes
  • Learning rate scheduling with warmup and cosine decay
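
The last three bullets combine into a loop pattern that looks roughly like the sketch below; `model`, `loader`, and every hyperparameter value here are placeholders, not the settings actually used in the project.

```python
import math

import torch
import torch.nn.functional as F


def lr_at(step: int, max_lr: float, warmup: int, total: int) -> float:
    """Linear warmup to max_lr, then cosine decay down to 10% of max_lr."""
    if step < warmup:
        return max_lr * (step + 1) / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return 0.1 * max_lr + 0.9 * max_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


def train(model, loader, accum_steps=8, max_lr=3e-4, warmup=1_000, total=100_000):
    opt = torch.optim.AdamW(model.parameters(), lr=max_lr, weight_decay=0.1)
    opt.zero_grad(set_to_none=True)
    for micro_step, (inputs, targets) in enumerate(loader):
        # Run the forward/backward pass in bfloat16; optimizer state stays in fp32.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(inputs)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        # Scale the loss so gradients accumulate to the mean over micro-batches.
        (loss / accum_steps).backward()
        if (micro_step + 1) % accum_steps == 0:
            step = micro_step // accum_steps
            for group in opt.param_groups:
                group["lr"] = lr_at(step, max_lr, warmup, total)
            opt.step()
            opt.zero_grad(set_to_none=True)
```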

3. Fine-tuning for Clinical QA

After pre-training, the model was fine-tuned on:

  • MedDialog and HealthCareMagic for conversational medical QA
  • MedQA and PubMedQA for multiple-choice medical reasoning
  • Custom instruction-following format for safe medical responses (an illustrative template appears after this list)
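
As an illustration, an instruction-style template for an MCQ item might look like the hypothetical sketch below; the wording, option layout, and the convention of computing loss only on the answer tokens are all assumptions rather than the project's actual format.

```python
def format_mcq(question, options, answer=None):
    """Render a MedQA-style item as an instruction-following prompt.

    `options` maps option letters ("A".."D") to option text; `answer` is the
    gold letter during fine-tuning and None at inference time.
    """
    option_block = "\n".join(f"{key}. {text}" for key, text in sorted(options.items()))
    prompt = (
        "You are a careful medical assistant. Answer with the letter of the "
        "single best option.\n\n"
        f"Question: {question}\n{option_block}\nAnswer:"
    )
    # During fine-tuning the gold letter is appended (and only those tokens
    # contribute to the loss); at inference the model completes the prompt.
    return f"{prompt} {answer}" if answer is not None else prompt
```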

Technical Details

The model architecture followed the LLaMA design pattern but at a smaller scale suitable for academic compute budgets (collected into a config sketch after this list):

  • Parameters: ~125M (comparable to GPT-2 small)
  • Context length: 2048 tokens
  • Attention heads: 12
  • Hidden dimension: 768
  • Layers: 12
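
Collected into a single configuration object, these hyperparameters look roughly like the sketch below; the field names and the vocabulary size are illustrative, not taken from the project.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    vocab_size: int = 32_000   # set by the BPE tokenizer; illustrative value
    n_layers: int = 12
    n_heads: int = 12
    hidden_dim: int = 768      # 768 / 12 heads = 64-dimensional heads
    max_seq_len: int = 2048
```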

Key technical decisions:

  • RoPE over learned positional embeddings: Better length generalization and computational efficiency
  • SwiGLU over GELU: Marginally better performance at the same parameter count (a sketch appears after this list)
  • RMSNorm: Simpler and slightly faster than LayerNorm, equally effective
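
For reference, a SwiGLU feed-forward block of the kind mentioned above can be sketched as follows; the 8/3 width multiplier follows LLaMA's convention and is an assumption here, not a documented choice of this project.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    """Feed-forward block with a SiLU-gated linear unit instead of a plain GELU MLP."""

    def __init__(self, dim: int, hidden=None):
        super().__init__()
        hidden = hidden or int(8 * dim / 3)
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # silu(W_gate x) gates the parallel projection W_up x, then project back down.
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
```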

Challenges & Tradeoffs

The biggest challenge was working within a tight compute budget (a single A100 for ~48h), which required careful decisions:

  • Model size vs. data quality: I opted for a smaller model trained on higher-quality, curated medical data rather than a larger model on noisy web data
  • Evaluation rigor: Medical QA evaluation is notoriously tricky — models can score well on MCQs by pattern matching without genuine understanding
  • Safety considerations: Ensuring the model doesn't generate harmful medical advice required careful instruction tuning and output filtering

Results

  • Achieved 62.3% accuracy on MedQA (5-shot), competitive with larger pre-trained models fine-tuned on the same data
  • 78.1% accuracy on PubMedQA (binary classification)
  • Qualitative evaluation showed coherent, contextually appropriate medical responses with suitable hedging and referral suggestions

What I Learned

  • Deep understanding of transformer internals — implementing attention from scratch forces you to understand how gradients flow through every operation
  • The importance of data quality over quantity, especially in specialized domains
  • How RLHF and instruction tuning fundamentally change model behavior beyond just loss reduction
  • Practical experience with distributed training patterns, even at smaller scale