jif: Jax text diffusion
Code: https://github.com/neverix/jif
This blog post is about some work I did in the summer of 2024. Since text diffusion models have become a hot topic with DeepMind's Gemini Diffusion and HKU NLP's Dream, I thought it would be a good idea to finish the draft for this now.
Introduction
Diffusion is currently the most popular method for training neural networks to generate images. It works by iteratively removing Gaussian noise from an image until it has the same distribution as the data. Why this works so well is a mystery, but it's likely because removing noise is, for most image distributions, the same objective as going from more to less blurry pictures autoregressively. The denoising process being iterative gives diffusion some nice properties:
- Score [1] functions can be mixed and matched: you can mix the noise predictions for different text conditions to create images matching both (diffusion can also be finetuned for exact MCMC sampling), or use gradients from an image classifier or encoder as a score for an unconditional model. Gaussian diffusion is flexible enough to be combined with a noisy linear observation, without training, for zero-shot inpainting, deblurring and more.
- The deterministic version of the diffusion process can be reversed by simply integrating the sampling process in the reverse direction. This is useful for finding noise latents that can generate images with a different condition and for finding the exact log likelihood of any given image. [2]
There have been many attempts to apply diffusion to text. Some of them formulate text diffusion as refining a distribution on a set of correlated discrete variables, while others model discrete tokens being randomly corrupted. I will ignore all of them except for the line of work this project is an implementation of.
SEDD
SEDD (Lou et al. 2024) is an early successful implementation of text diffusion at GPT-2 scales. It doesn't match autoregressive modelling in (upper bounds of) NLL, but it can generate coherent text and even has induction-like heads [3]. Its forward diffusion process is a continuous-time Markov chain where the entire state of the sequence is a single giant discrete variable:

$$\frac{dp_t}{dt} = Q_t\, p_t, \qquad p_0 \approx p_{\text{data}}.$$
Of course, they factorize the matrix and the probability density on the token dimension: transitions only change one token at a time,

$$Q_t(\hat{x}, x) = \begin{cases} Q^{\text{tok}}_t(\hat{x}_i, x_i) & \text{if } \hat{x} \text{ and } x \text{ differ only at position } i, \\ 0 & \text{otherwise,} \end{cases}$$

so only ratios between sequences differing in a single token are ever needed.
It can be reversed as:

$$\frac{dp_{T-t}}{dt} = \bar{Q}_{T-t}\, p_{T-t}, \qquad \bar{Q}_t(y, x) = \frac{p_t(y)}{p_t(x)}\, Q_t(x, y) \;\; \text{for } y \ne x,$$

with the diagonal entries chosen so that the columns sum to zero.
...Which is similar to the reversal of the forward process of Gaussian diffusion models based on Tweedie's formula, but instead of the score we have the ratios $\frac{p_t(y)}{p_t(x)}$. The goal of training is then learning these ratios, or a factorization of them over positions. The objective in Lou et al. is score entropy. Given predicted ratios $s_\theta(x)_y \approx \frac{p_t(y)}{p_t(x)}$:

$$\mathcal{L}_{\mathrm{SE}} = \mathbb{E}_{x \sim p_t} \sum_{y \ne x} w_{xy} \left( s_\theta(x)_y - \frac{p_t(y)}{p_t(x)} \log s_\theta(x)_y + K\!\left(\frac{p_t(y)}{p_t(x)}\right) \right),$$

where $K(a) = a(\log a - 1)$ is a normalizing constant function that ensures that $\mathcal{L}_{\mathrm{SE}} \ge 0$.
They prove some properties, like that minimizing it requires learning the correct ratios and that it scales down the gradients by the predicted ratios, preventing blow-ups. There is also a connection to Bregman divergences ($D_F(a, b) = F(a) - F(b) - \langle \nabla F(b),\, a - b \rangle$ for convex $F$): taking $F(a) = a \log a - a$ gives

$$D_F\!\left(\frac{p_t(y)}{p_t(x)},\, s_\theta(x)_y\right) = s_\theta(x)_y - \frac{p_t(y)}{p_t(x)} \log s_\theta(x)_y + K\!\left(\frac{p_t(y)}{p_t(x)}\right),$$

...Which is exactly the formula for SE, if we substitute in $K$ and take into account the constraint that $y \ne x$ in the sum.
SEDD considers two choices for the matrix $Q$: uniform and absorbing. The uniform matrix slowly decays the data into a uniform distribution over tokens, while with the absorbing matrix each token has a small chance of jumping into a [MASK] state at each timestep. This is similar to BERT masking, but extended to all masking ratios and with a fancy loss instead of crossentropy. It turns out that absorbing SEDD works better - even better than GPT-2 on WikiText-2 and PTB (though with a lot more training).
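To make the absorbing process concrete, here is a minimal sketch (the mask id and function name are made up, not taken from the SEDD code) of how a batch of clean tokens gets corrupted at masking probability p_mask:

```python
import jax
import jax.numpy as jnp

MASK_ID = 31_999  # hypothetical id of the [MASK] token added to the vocabulary

def corrupt(key, tokens, p_mask):
    # Each token independently decays into [MASK] with probability p_mask,
    # which grows from 0 to 1 over the course of the forward process.
    drop = jax.random.bernoulli(key, p_mask, tokens.shape)
    return jnp.where(drop, MASK_ID, tokens)
```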
It turns out that SE is equivalent to crossentropy in the absorbing case, up to a constant. For a single masked token whose clean value is $x_0$, if the predicted ratio for a token $y$ is $s_\theta(x)_y$, score entropy is

$$\sum_{y \ne \text{MASK}} \left( s_\theta(x)_y - \frac{e^{-\bar\sigma}}{1 - e^{-\bar\sigma}}\, \mathbb{1}[y = x_0]\, \log s_\theta(x)_y \right) + \text{const},$$

...Where $\bar\sigma(t) = \int_0^t \sigma(s)\,ds$ is the total magnitude of the noise. The probability a token is equal to noise at any given point is $1 - e^{-\bar\sigma(t)}$. Then $p_t(x_0 \mid x_0) = e^{-\bar\sigma(t)}$ and $p_t(\text{MASK} \mid x_0) = 1 - e^{-\bar\sigma(t)}$, so the target $\frac{e^{-\bar\sigma}}{1 - e^{-\bar\sigma}}$ is exactly the ratio of the probabilities of the token being unmasked or masked. Note that this target is only nonzero when $y$ is the ground-truth token $x_0$; if it were not, the probability of transitioning to it would be zero. Continuing,

$$\sum_{y \ne \text{MASK}} s_\theta(x)_y \;-\; \frac{e^{-\bar\sigma}}{1 - e^{-\bar\sigma}}\, \log s_\theta(x)_{x_0} \;+\; \text{const}.$$
This objective is much simpler. The objective of the network is then to maximize $\log s_\theta(x)_{x_0}$ and minimize $\sum_{y \ne \text{MASK}} s_\theta(x)_y$. At this point, intuitively, this looks similar to crossentropy/negative log-likelihood, but with unnormalized logits: the network needs to allocate the score to the tokens it thinks are more likely and can't increase the total by too much. In SEDD, the prediction is actually the log-space score $\ell_y = \log s_\theta(x)_y$, so the objective becomes:

$$\sum_{y \ne \text{MASK}} e^{\ell_y} \;-\; \frac{e^{-\bar\sigma}}{1 - e^{-\bar\sigma}}\, \ell_{x_0}.$$
Given that the logsumexp $\mathrm{LSE}(\ell) = \log \sum_{y \ne \text{MASK}} e^{\ell_y}$ is a single scalar, we can separate out the crossentropy subproblem and worry about optimizing the rest later, with the remaining objective purely in terms of it:

$$\frac{e^{-\bar\sigma}}{1 - e^{-\bar\sigma}} \underbrace{\left( \mathrm{LSE}(\ell) - \ell_{x_0} \right)}_{\text{crossentropy}} \;+\; e^{\mathrm{LSE}(\ell)} \;-\; \frac{e^{-\bar\sigma}}{1 - e^{-\bar\sigma}}\, \mathrm{LSE}(\ell).$$
...Which is exactly the crossentropy loss, plus a log-sum-exp term. The log-sum-exp is not optimized for during traditional NLL training, so this objective would be satisfied by a network minimizing crossentropy.
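As a sanity check of the algebra, here is a small numerical sketch (toy shapes, not the training code) confirming that the absorbing score-entropy term equals the weighted crossentropy plus terms that depend only on the log-sum-exp:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def se_term(log_scores, target, c):
    # sum_y exp(l_y) - c * l_{target}, dropping the constant K(c)
    return jnp.sum(jnp.exp(log_scores)) - c * log_scores[target]

def decomposed(log_scores, target, c):
    lse = logsumexp(log_scores)
    crossentropy = lse - log_scores[target]   # NLL of the normalized logits
    return c * crossentropy + jnp.exp(lse) - c * lse

log_scores = jax.random.normal(jax.random.PRNGKey(0), (32,))  # toy vocabulary of 32 tokens
target, sigma_bar = 7, 1.3
c = jnp.exp(-sigma_bar) / (1.0 - jnp.exp(-sigma_bar))
print(jnp.allclose(se_term(log_scores, target, c), decomposed(log_scores, target, c)))
```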
Notice that the score of the MASK token is not optimized for in this setup. If one inspects a trained SEDD model, the logit weights and biases corresponding to it are all zero, which would make sense given it received updates only from weight decay.
Still, SEDD was the first text diffusion model with a simple formulation and comparable performance to GPT-2. I initially based my implementation on it, the diffusion transformer being a one-to-one port of the SEDD code.
MDLM
MDLM (Sahoo et al. 2024) is the formulation of masked text diffusion used by most of the recent text diffusion models. It is a masked denoising language model with the same schedule as SEDD, but with a simplified objective: the weighted crossentropy (negative log-likelihood) derived above.
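A minimal sketch of what this simplified objective looks like in code, assuming a linear schedule where the masking probability equals t (names and shapes are illustrative, not the jif implementation):

```python
import jax.numpy as jnp
import optax

def masked_diffusion_loss(logits, targets, mask, t):
    # logits: (batch, seq, vocab); targets: (batch, seq) clean token ids
    # mask: (batch, seq) True where the input token was [MASK]; t: (batch,) mask ratios
    ce = optax.softmax_cross_entropy_with_integer_labels(logits, targets)  # (batch, seq)
    # Crossentropy only on masked positions, weighted by 1/t (continuous-time NELBO
    # weighting for the assumed linear schedule).
    return jnp.mean((ce * mask).sum(axis=-1) / t)
```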
Once MDLM came out, I ported the diffusion process over to my incomplete codebase. It showed much better results on toy tasks, and I removed SEDD.
jif
The following is a description of the codebase.
Data & Setup
The dataset is TinyStories-1M. The tokenizer is GPT-2's, trimmed to the first 32k tokens to decrease memory usage for logits and embeddings. The models are trained on context windows of 128 tokens with a batch size of 256. The masking ratio is uniformly distributed between 0 and 1, with MDLM's low-discrepancy sampling (using a linspace to set timesteps instead of an RNG).
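A sketch of what low-discrepancy sampling of the mask ratio can look like (details assumed; jif's version may differ): spread a linspace-style grid over [0, 1), optionally with a single shared random offset, so every batch covers all noise levels evenly.

```python
import jax
import jax.numpy as jnp

def sample_mask_ratios(key, batch_size):
    # Set offset = 0.0 for a pure linspace; a shared offset keeps the draws
    # random while the grid keeps them spread out across [0, 1).
    offset = jax.random.uniform(key, ())
    return (jnp.arange(batch_size) + offset) / batch_size
```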
When sampling, we use 500 steps. If a network is not timestep-conditioned (and is instead conditioned on the number of unmasked tokens), it is possible to cache the logits from the previous timestep and reuse them if no token was unmasked, as described in MDLM.
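For concreteness, a rough sketch (not the actual jif sampler; the mask id and helper are made up) of one reverse step: sample a candidate for every position, then reveal each still-masked position independently with the probability implied by the schedule.

```python
import jax
import jax.numpy as jnp

MASK_ID = 31_999  # hypothetical id of the [MASK] token

def reverse_step(key, tokens, logits, p_unmask):
    # tokens: (seq,) partially masked sequence; logits: (seq, vocab) model output.
    k_sample, k_reveal = jax.random.split(key)
    candidates = jax.random.categorical(k_sample, logits)            # one draw per position
    reveal = jax.random.bernoulli(k_reveal, p_unmask, tokens.shape)  # schedule-dependent
    return jnp.where((tokens == MASK_ID) & reveal, candidates, tokens)
```

If no position gets revealed in a step, the network's input is unchanged, which is what makes the logit caching above valid for a network that is not timestep-conditioned.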
Model & Sharding
The model is a diffusion transformer as described in SEDD (source code). The model code is implemented in Penzai, and the network is built largely declaratively from Penzai's pre-made components. We use Penzai's side-input system to pass the DiT conditioning along the module tree without adding it as an argument to each module. The Penzai implementation also means all axes in the model's parameters are named and can be sharded automatically. We shard along the "neurons", "kv_heads" and "vocabulary" named axes, which correspond to the inner dimension of the transformer's MLP, the attention heads and the vocabulary. This means we use TP for the MLP and KV heads and let Jax figure out FSDP for the vocabulary.
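The Penzai specifics are out of scope here, but in plain Jax terms the sharding amounts to something like the following sketch (mesh layout and parameter shapes are made up):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("tp",))

# Hypothetical parameter shapes: (d_model, neurons) for the MLP in-projection,
# (vocabulary, d_model) for the embedding table.
mlp_in = jnp.zeros((512, 2048))
embed = jnp.zeros((32_000, 512))

# Shard the "neurons" axis of the MLP and the "vocabulary" axis of the embedding.
mlp_in = jax.device_put(mlp_in, NamedSharding(mesh, P(None, "tp")))
embed = jax.device_put(embed, NamedSharding(mesh, P("tp", None)))
```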
Note: in Jax 0.6, Penzai's default embedding layer was very slow. This is because it internally runs something like vmap(lambda i: x[i]), which makes Jax use a scatter-add for the backward pass, and that can't be optimized very well on TPUs. I replaced the vmap-of-lookup with a single indexing operation and reached around 15% MFU.
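Illustratively (shapes assumed, not the Penzai internals), the difference is between a per-element lookup under vmap and one indexing operation over the whole id array:

```python
import jax
import jax.numpy as jnp

table = jnp.zeros((32_000, 512))              # (vocab, d_model) embedding table
ids = jnp.zeros((256, 128), dtype=jnp.int32)  # (batch, seq) token ids

# Slow path: per-element lookups, whose backward pass lowers to scatter-adds.
slow = jax.vmap(jax.vmap(lambda i: table[i]))(ids)

# Fast path: a single gather over the whole id array.
fast = table[ids]                             # equivalently jnp.take(table, ids, axis=0)
```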
Training
We optimize the model with AdamW. We tried Muon and Schedule-Free Adam without improvement. We keep an exponential moving average of the weights with a decay of 0.995 like in MDLM. The model generates plausible text on a v4-8 within a minute of compilation.
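For reference, a minimal sketch of the optimizer setup and the weight EMA (the hyperparameters other than the EMA decay are placeholders):

```python
import jax
import optax

optimizer = optax.adamw(learning_rate=3e-4, weight_decay=0.01)  # placeholder values

def update_ema(ema_params, params, decay=0.995):
    # Exponential moving average of the weights, kept alongside the optimizer
    # state and used for evaluation/sampling.
    return jax.tree_util.tree_map(lambda e, p: decay * e + (1.0 - decay) * p,
                                  ema_params, params)
```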
Future work
Normalization
This project was started with the goal of figuring out a MuP-normalized architecture for diffusion models. Specifically, I wanted to implement Modula, an architecture and optimizer combination that gives the network a Lipschitz constant of 1 at initialization and ensures that updates cannot increase the loss by more than a multiple of a specially tuned norm of the update. I ended up working on other things, but the repo contains a script for testing learning rate transfer, and Penzai is well-suited for implementing Modula on top of.
Statefulness
Text diffusion has an inherent disadvantage compared to autoregression: it uses compute proportional to the sequence length for every sampled token, having to do a forward pass through all tokens to sample a single one. This could turn into an advantage if it can process the context more deeply than autoregression, where each token's embedding has only previously sampled tokens as context. However, there is no way for any diffusion step to receive information from previous steps other than through sampled tokens, so this is difficult to do. To fix this, we could follow Bit Diffusion (Chen et al. 2022) in self-conditioning the model on detached-gradient predictions of the previous step. Still, I expect any such method to be tricky to implement and prone to instability, like any recurrent model.
Training objective
Beyond Autoregression: Fast LLMs via Self-Distillation Through Time (Deschenaux & Gulcehre, 2024) introduces a simple method for distilling discrete diffusion models to sample with fewer timesteps, and it works at the 1B-parameter scale. It is still an open question whether this scales up well and preserves the diversity of generations, but it should be possible to implement in jif.
PGM improves on the masking objective for diffusion model training by replacing it with grouped denoising: partitioning the tokens into two groups, allocating the first few layers to processing the unmasked tokens and the last few to processing only the masked tokens, with a cross-attention layer connecting them. This means the model can be trained to predict all tokens at once instead of 50% of them on average, and it can use a smaller attention matrix.
Something that I think could losslessly improve generation is optimizing the choice of tokens to sample and unmask at each step. The reason we can't unmask all tokens at once is that they are conditionally dependent on each other, so the unconditional prediction will be inaccurate. We could try to fix this with something like MTP adapted for encoder-only models, but a simpler way around it would be to sample tokens that are roughly conditionally independent. One way to do this is to train a linear probe that embeds the final layer's activations into a lower-dimensional space, compute pairwise dot products between those low-dimensional embeddings, and train the dot products to predict the decrease in crossentropy from conditioning one token on another (measured by removing the mask from a group of masked tokens and computing the effect on all other masked tokens). If this is easy to predict (I expect that it is, since in the worst case the model can simply use the distance between the tokens as a proxy), we can pick the tokens to unmask greedily, at each step choosing the one that minimizes the predicted interference with all tokens already selected. This would have the benefit of being a post-training optimization that doesn't interfere with the training objective.
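A rough sketch of the selection step of this (untested) idea; every name here is hypothetical:

```python
import jax.numpy as jnp

def greedy_unmask_order(probe_embeds, num_to_unmask):
    # probe_embeds: (num_masked, d_probe) low-dimensional probe outputs whose
    # pairwise dot products are trained to predict cross-token interference.
    interference = probe_embeds @ probe_embeds.T
    selected = [0]  # seed with an arbitrary masked position
    for _ in range(num_to_unmask - 1):
        idx = jnp.array(selected)
        total = interference[:, idx].sum(axis=1)  # interference with current picks
        total = total.at[idx].set(jnp.inf)        # never re-pick a position
        selected.append(int(jnp.argmin(total)))
    return selected
```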
1. Gradient of log density; this is the noise up to a constant factor for Gaussian denoising diffusion.
2. Actually computing the exact log likelihood requires finding the prior probability of the input noise and the log determinant of the Jacobian using the Hutchinson trace estimator (Reference).
3. Source: some research I probably won't publish.