Building diffusionGPT: My Journey into Discrete Diffusion
Most people know diffusion models for generating images, but I wanted to see if I could make them work for text—specifically, to build a conversational AI that doesn't just predict the next token left-to-right, but iteratively "refines" its thoughts.
This is the story of diffusionGPT, the technical choices I made, and the unexpected walls I hit along the way.
Why Discrete Diffusion?
Unlike traditional autoregressive models such as Llama or GPT, which generate text one token at a time from left to right, diffusionGPT uses a discrete diffusion process: it starts from a fully masked sequence and denoises all positions together over a series of steps.
This approach unlocks something fascinating: Parallel Decoding. Instead of waiting for token t before generating token t+1, the model predicts and refines many tokens at the same time. As a result, it can produce a sequence in fewer forward passes than an autoregressive model, which needs one pass per generated token.
Under the Hood: The Algorithm
I based my core methodology on Masked Diffusion Language Models (MDLM). The process is a game of corruption and recovery:
- Training Time: We sample a timestep t ~ Uniform(0, 1), mask each token with probability t (so roughly that proportion of the sequence is masked), and train the model to predict the original tokens at the masked positions.
- Inference Time: We start with a fully masked sequence. At each step, the model predicts every masked position, and we keep only a fraction of those predictions, masking the rest again. We repeat this for a fixed number of steps, never re-masking tokens that have already been revealed.
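Both halves of this process can be sketched without any framework. This is a minimal illustration under stated assumptions, not the diffusionGPT code: MASK_ID and both helper names are hypothetical, sequences are plain lists of token ids, and the timestep-dependent loss weighting from the MDLM paper is left out. reveal_schedule assumes a simple linear schedule that reveals roughly the same number of tokens at every step.

```python
import random

MASK_ID = 0  # hypothetical id of the [MASK] token


def mdlm_corrupt(tokens, rng):
    """Training-time forward process: sample t ~ U(0, 1), mask each token with prob. t.

    Returns the noisy input, a per-position loss mask (plain MDLM only scores
    the masked positions) and the sampled t.
    """
    t = rng.random()  # uniform in [0, 1)
    noisy, loss_mask = [], []
    for tok in tokens:
        if rng.random() < t:
            noisy.append(MASK_ID)
            loss_mask.append(True)   # the model must recover the original token here
        else:
            noisy.append(tok)
            loss_mask.append(False)  # visible token, not scored
    return noisy, loss_mask, t


def reveal_schedule(seq_len, num_steps):
    """Inference-time schedule (assumed linear): tokens to reveal at each step.

    Spreads seq_len tokens as evenly as possible over num_steps forward passes.
    """
    base, extra = divmod(seq_len, num_steps)
    return [base + (1 if i < extra else 0) for i in range(num_steps)]
```

For example, reveal_schedule(10, 4) returns [3, 3, 2, 2]: ten tokens in four forward passes, where an autoregressive model would need ten.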
To improve on this, I followed the implementation of Seed Diffusion by ByteDance, which also lets the model edit already visible tokens to correct its own mistakes.
The Architecture: Why I Chose ModernBERT
Instead of starting from a standard decoder-only model, I deliberately chose ModernBERT (by Answer.AI) as the backbone.
Why? Masked diffusion models need bidirectional attention. To "refine" a token in the middle of a sentence, the model needs to see the context both before and after it, and encoder-only architectures provide this out of the box. The training task is also very close to the one used to pretrain BERT models, predicting masked tokens; the main difference is that the masking ratio varies between 0 and 1 instead of being fixed (15% in the original BERT).
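The difference between the two attention patterns fits in a few lines. This toy sketch builds boolean attention masks as nested lists (the helper names are hypothetical) and shows which positions a masked token in the middle of a 5-token sequence can use. It only models full global attention and ignores the sliding-window local attention layers that ModernBERT also uses, which are bidirectional as well.

```python
def causal_mask(n):
    """Decoder-style mask: position i may attend only to positions j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    """Encoder-style mask (BERT-like): every position attends to every position."""
    return [[True] * n for _ in range(n)]


def visible_context(mask, i):
    """Positions that position i can look at when predicting its token."""
    return [j for j, allowed in enumerate(mask[i]) if allowed]


# A masked token at position 2 of a 5-token sequence:
print(visible_context(causal_mask(5), 2))         # [0, 1, 2]
print(visible_context(bidirectional_mask(5), 2))  # [0, 1, 2, 3, 4]
```

With the causal mask the model never sees positions 3 and 4, so it cannot use the right-hand context that refining a mid-sentence token depends on.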
The Engineering Hurdles
Standard Causal Language Modeling (CLM) trainers just didn't cut it for this, so I had to build a custom training loop from the ground up. Well, almost. I had already done several projects implementing things from scratch, so this time I wanted to learn how to use the Hugging Face ecosystem effectively rather than rebuilding everything myself.
1. The Missing Pipeline
Hugging Face Transformers has no standard generate() method for diffusion language models. I had to build a custom TextDiffusionPipeline to handle:
- Ancestral Sampling: Iteratively sampling from the model's logits.
- Confidence-Based Unmasking: Logic to selectively unmask only the most "confident" tokens at each step.
Technical Note: I implemented a very basic form of confidence-based unmasking, assuming it would outperform random unmasking. In my experiments, however, random unmasking often worked better: confidence-based unmasking frequently sent the model into a never-ending loop of repeating itself. Investigating how current implementations handle this is on my to-do list.
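The two unmasking strategies differ only in which masked positions get committed at each step. The sketch below is a simplified, framework-free illustration, not the TextDiffusionPipeline code: unmask_step and MASK_ID are hypothetical, the model output is given as one probability list per position, and "confidence" is taken to mean the probability of the sampled token (other definitions, such as the top-1 probability or the margin between the top two, are also possible).

```python
import random

MASK_ID = 0  # hypothetical [MASK] token id


def unmask_step(seq, probs, k, strategy, rng):
    """Reveal up to k masked positions of seq (modified in place and returned).

    probs[i] is the model's distribution over the vocabulary at position i;
    it should give zero weight to MASK_ID so a sampled token is never a mask.
    """
    masked = [i for i, tok in enumerate(seq) if tok == MASK_ID]
    # Sample a candidate token for every masked position from the model's distribution.
    candidates = {
        i: rng.choices(range(len(probs[i])), weights=probs[i])[0] for i in masked
    }
    if strategy == "random":
        chosen = rng.sample(masked, min(k, len(masked)))
    elif strategy == "confidence":
        # Commit the positions whose sampled token the model is most sure about.
        chosen = sorted(masked, key=lambda i: probs[i][candidates[i]], reverse=True)[:k]
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    for i in chosen:
        seq[i] = candidates[i]  # every other masked position stays masked
    return seq
```

Calling this once per step with the counts from a reveal schedule gives the full denoising loop; swapping the strategy string is all it takes to compare the two behaviors.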
2. Implementing Seed Diffusion
To support editing, I implemented a two-stage curriculum from Seed Diffusion:
- Stage 1 (0-80%): Standard MDLM objective (pure masking).
- Stage 2 (80-100%): I introduce corruption of the "gold" (ground-truth) tokens. This teaches the model to refine existing incorrect tokens rather than just fill in blanks. It also introduces the <|delete|> token, which lets the model learn to remove tokens (note: delete corruption is not yet implemented for SFT).
Lesson Learned: In my experiments, keeping the edit stage active for the entire training run actually gave better results. I suspect this is because the edit objective computes the loss on all tokens, not just the masked ones. Since the model is trained on little data with a short sequence length, computing the loss only on masked tokens does not seem to provide enough signal for it to learn language modeling.
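The curriculum can be sketched as a single corruption function that switches behavior at a given point in training. This is a heavily simplified, assumption-laden illustration: corrupt_for_training, MASK_ID and edit_rate are hypothetical, the edit corruption here is substitution only (random replacement of visible tokens), and the <|delete|> corruption is omitted entirely. Setting edit_start=0.0 reproduces the "edit stage for the whole run" variant from the lesson above.

```python
import random

MASK_ID = 0  # hypothetical [MASK] token id


def corrupt_for_training(tokens, progress, vocab_size, rng,
                         edit_start=0.8, edit_rate=0.1):
    """Build (noisy input, loss mask) for one training example.

    progress:   fraction of training completed, in [0, 1]
    edit_start: when the edit stage begins (0.8 = last 20% of training,
                0.0 = edit stage active for the whole run)
    edit_rate:  hypothetical probability of corrupting a visible token
    """
    t = rng.random()
    editing = progress >= edit_start
    noisy, loss_mask = [], []
    for tok in tokens:
        if rng.random() < t:
            noisy.append(MASK_ID)  # standard MDLM masking
        elif editing and rng.random() < edit_rate:
            # Substitute a random non-mask token (it may occasionally equal the original).
            noisy.append(rng.randrange(1, vocab_size))
        else:
            noisy.append(tok)
        # Stage 1 scores only masked positions; the edit stage scores every position,
        # so the model must also keep correct tokens and repair corrupted ones.
        loss_mask.append(editing or noisy[-1] == MASK_ID)
    return noisy, loss_mask
```

The targets stay the original tokens in both stages; only the input corruption and the set of scored positions change.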
The "It Somewhat Works" Phase
I started small with the TinyStories dataset. It worked! It generated reasonable stories. With this model, I tested the confidence-based unmasking approach and the semi-autoregressive inference mode, which can keep generating indefinitely (see video below).
Then I scaled up to FineWeb, and it didn't go well: the model struggled to generate coherent text. My first thought was that I had broken something, so I went back to debugging on TinyStories, but everything there seemed fine. It seems it was just a matter of hyperparameters. In hindsight, I should have run a hyperparameter sweep instead of doing a full training run just to "see if it works". I decided to continue with SFT anyway, if only to debug the SFT stage on this model.
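Semi-autoregressive generation can be pictured as diffusion applied one block at a time: append a block of masks, denoise it while earlier blocks stay frozen as context, then move on until an end token appears. The sketch below is a minimal illustration under assumptions: generate_semi_ar, MASK_ID, EOS_ID and the denoise callback are hypothetical, and the block size and stopping rule are illustrative choices rather than the pipeline's actual settings.

```python
MASK_ID = 0  # hypothetical [MASK] token id
EOS_ID = 2   # hypothetical end-of-text token id


def generate_semi_ar(prompt, denoise, block_size, max_blocks):
    """Diffuse one block at a time, left to right.

    denoise(seq, start) must fill every MASK_ID at positions >= start and return
    the new sequence, e.g. by running the unmasking loop on that block only.
    """
    seq = list(prompt)
    for _ in range(max_blocks):
        start = len(seq)
        seq = denoise(seq + [MASK_ID] * block_size, start)
        if EOS_ID in seq[start:]:
            # Stop right after the first end token inside the new block.
            return seq[: seq.index(EOS_ID, start) + 1]
    return seq
```

Because generation only stops when the model emits an end token, the output length is not fixed in advance, which is what makes open-ended generation possible.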
The SFT Struggle
My experiment with Supervised Fine-Tuning (SFT) was a rollercoaster. I tested it on smol-smoltalk and, surprisingly, it somewhat worked. Because the model has no causal mask, every token attends to the whole sequence, so I had to split each conversation into samples that only contain the chat up to the current assistant response; otherwise the model could see later turns.
First, I trained on the full conversation, not just the assistant response. The model learned to generate coherent answers and also worked with the semi-autoregressive inference mode. However, it often got confused and wrote as the user instead of answering as the assistant.
Then I switched to training only on the assistant answers. That seemed to work better, so I added the everyday-conversations-llama3.1-2k and Nemotron-Instruction-Following-Chat-v1 datasets. After training finished, the model could answer questions reliably.
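Splitting a conversation this way is straightforward with chat messages in the usual role/content dict format. This is a minimal sketch; split_conversation is a hypothetical helper, and a real pipeline would then render each sample with the tokenizer's chat template.

```python
def split_conversation(messages):
    """Turn one multi-turn chat into one training sample per assistant reply.

    messages: list of {"role": ..., "content": ...} dicts in chat order.
    Each sample ends at one assistant reply and contains nothing after it,
    because a bidirectional model would otherwise attend to later turns that
    do not exist yet at inference time.
    """
    samples = []
    for i, msg in enumerate(messages):
        if msg["role"] == "assistant":
            samples.append(messages[: i + 1])
    return samples


chat = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "4"},
]
print(len(split_conversation(chat)))  # 2
```

A two-turn chat therefore yields two samples: one ending at "Hello!" and one containing the full exchange.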
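Training only on the assistant answers amounts to keeping the prompt visible and unscored while corrupting the response. The sketch below assumes prompt and response are already tokenized into separate id lists; sft_corrupt and MASK_ID are hypothetical. Training on the full conversation, as in the first attempt, would instead apply the masking to the prompt tokens as well.

```python
import random

MASK_ID = 0  # hypothetical [MASK] token id


def sft_corrupt(prompt_ids, response_ids, rng):
    """Mask only the assistant response; the prompt stays visible and is never scored."""
    t = rng.random()
    noisy = list(prompt_ids)
    loss_mask = [False] * len(prompt_ids)  # no loss on system/user tokens
    for tok in response_ids:
        masked = rng.random() < t
        noisy.append(MASK_ID if masked else tok)
        loss_mask.append(masked)  # only masked response tokens contribute to the loss
    return noisy, loss_mask
```

Since the prompt is never masked, the model always conditions on the full user turn and never learns to produce user text.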
The <|im_end|> Problem:
Now I finally had a chatbot, but there was a catch: the model always generated the <|im_end|> token at the very last position of the sequence, regardless of the answer's length. This broke the semi-autoregressive capabilities I was excited about, though it did allow for controlling the exact answer length. The cause lies in the SFT data: every sample ends with <|im_end|>, so the model learns that the final position always holds that token.
I still have to figure out how to solve this issue. Restoring the semi-autoregressive capabilities would enable "thinking" and make the model more suitable for RL, both of which I would like to add in the future.
Roadmap: What's Next?
This is just the beginning. There is a lot to polish and improve in the current implementation. The roadmap for diffusionGPT includes:
- Evaluation Suite: Implement proper benchmarks (beyond loss) to evaluate generation quality quantitatively.
- Fixing Semi-Autoregressive SFT: Modifying masking so the model isn't "forced" to end every sequence with an end token.
- Confidence-based sampling: Improve the current basic implementation of confidence-based sampling.
- "Thinking" Capability: Exploring Chain of Thought (CoT) reasoning.
- Tool use: Adding tool use/function calling capabilities.
- Reinforcement Learning: Using RL to improve its math capabilities.
- Deletion Corruption on SFT: Teaching the model to explicitly remove bad tokens, not just replace them.
The code is open source, and you can check it out on GitHub. The model can be found on Hugging Face.