jif: Jax text diffusion

Code: https://github.com/neverix/jif

This blog post is about some work I did in the summer of 2024. Since text diffusion models have become a hot topic with DeepMind's Gemini Diffusion and HKU NLP's Dream, I thought it would be a good idea to finish the draft for this now.

Introduction

Diffusion is currently the most popular method for training neural networks to generate images. It works by iteratively removing Gaussian noise from an image until it has the same distribution as the data. Why this works so well is a mystery, but it's likely because removing noise is, for most image distributions, the same objective as going from more to less blurry pictures autoregressively. The iterative nature of the denoising process also gives diffusion some nice properties.

There have been many attempts to apply diffusion to text. Some of them formulate text diffusion as refining a distribution on a set of correlated discrete variables, while others model discrete tokens being randomly corrupted. I will ignore all of them except for the line of work this project is an implementation of.

SEDD

SEDD (Lou et al., 2024) is an early successful implementation of text diffusion at GPT-2 scale. It doesn't match autoregressive modelling in (upper bounds of) NLL, but it can generate coherent text and even has induction-like heads [3]. Its forward diffusion process is a continuous-time Markov chain where the entire state of the sequence is a single giant discrete variable:

$$p \in \Delta^n, \quad Q \in \mathbb{R}^{n \times n}, \qquad \frac{dp_t}{dt} = Q_t p_t, \quad p_0 \approx p_{\text{data}}$$

Of course, they factorize the matrix and probability density on the token dimension:

$$\mathcal{X} = \{ 1, ..., N \} = \{ 1, ..., n \}^d, \quad \mathbf{x} = x_1 ... x_d, \qquad Q_t(x_1 ... x_i ... x_d,\; x_1 ... \hat{x}_i ... x_d) = Q^{\text{tok}}_t(x_i, \hat{x}_i)$$

It can be reversed as:

$$\frac{dp_{T - t}}{dt} = \overline{Q}_{T - t}\, p_{T - t}, \qquad \overline{Q}_t(y, x) = \frac{p_t(y)}{p_t(x)} Q_t(x, y), \qquad \overline{Q}_t(x, x) = - \sum_{y \ne x} \overline{Q}_t(y, x)$$

...Which is similar to the reversal of the forward process of Gaussian diffusion models based on Tweedie's formula, but instead of the score $\nabla_x \log p_t$ we have the ratios $\frac{p_t(y)}{p_t(x)}$. The goal of training is then learning these ratios, or rather their per-token factorization. The objective in Lou et al. is score entropy. Given predicted ratios $s_{\theta}(x)_y$:

$$\mathcal{L}_{\text{SE}} = \mathbb{E}_{x \sim p} \left[ \sum_{y \neq x} w_{xy} \left( s_{\theta}(x)_y - \frac{p(y)}{p(x)} \log s_{\theta}(x)_y + K\left(\frac{p(y)}{p(x)} \right) \right) \right]$$

where $K(a) = a (\log a - 1)$ is a normalizing constant function that ensures that $\mathcal{L}_{\text{SE}} \ge 0$.
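For concreteness, here is a minimal sketch of one $(x, y)$ term of this loss in JAX. The function name, the argument layout, and the handling of zero ratios are my own assumptions rather than the SEDD reference code; the sum over $y \ne x$ and the expectation over $x \sim p$ are left to the caller.

```python
import jax.numpy as jnp

def score_entropy_term(pred_ratio, true_ratio, weight):
    # One (x, y) term of score entropy: w_xy * (s - r*log(s) + K(r)),
    # with K(a) = a * (log a - 1), pred_ratio = s_theta(x)_y,
    # true_ratio = p(y)/p(x), weight = w_xy.
    # Zero true ratios (common in the absorbing case) contribute only the
    # s term, since both r*log(s) and K(r) vanish as r -> 0.
    safe_ratio = jnp.where(true_ratio > 0, true_ratio, 1.0)
    k = jnp.where(true_ratio > 0, safe_ratio * (jnp.log(safe_ratio) - 1.0), 0.0)
    cross = jnp.where(true_ratio > 0, safe_ratio * jnp.log(pred_ratio), 0.0)
    return weight * (pred_ratio - cross + k)
```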

They prove some properties, like that minimizing it requires learning the correct ratios and that it scales down the gradients by the predicted ratios, preventing blow-ups. There is also a connection to Bregman divergences ($F(p) - F(q) - \langle \nabla F(q), p - q \rangle$ for convex $F$):

$$\begin{aligned}
&-\log\left(s_{\theta}(x)_y\right) + \log\left(\frac{p(y)}{p(x)}\right) - \left(s_{\theta}(x)_y - \frac{p(y)}{p(x)}\right)\left(-\frac{p(x)}{p(y)}\right) \\
&= -\log\left(s_{\theta}(x)_y\right) + \log\left(\frac{p(y)}{p(x)}\right) - \left(s_{\theta}(x)_y\right)\left(-\frac{p(x)}{p(y)}\right) - 1 \\
&= -\log\left(s_{\theta}(x)_y\right) + \log\left(\frac{p(y)}{p(x)}\right) + \frac{p(x)}{p(y)} s_{\theta}(x)_y - 1 \\
&= -\log\left(s_{\theta}(x)_y\right) + \frac{p(x)}{p(y)} s_{\theta}(x)_y - 1 + \log\left(\frac{p(y)}{p(x)}\right) \\
&= \frac{p(x)}{p(y)} \left( s_{\theta}(x)_y - \frac{p(y)}{p(x)}\log\left(s_{\theta}(x)_y\right) + \frac{p(y)}{p(x)} \log\left(\frac{p(y)}{p(x)}\right) - \frac{p(y)}{p(x)} \right)
\end{aligned}$$

...Which is exactly the formula for SE, if we substitute in $K$ and take into account the constraint that $y \ne x$ in the sum.

SEDD considers two choices for the matrix $Q$: uniform and absorbing. The uniform matrix slowly decays $p_t$ into a uniform distribution over tokens, while with the absorbing matrix there is a small chance of each token jumping into a [MASK] state at each timestep. It is similar to BERT masking, but extended to all masking ratios and with a fancy loss instead of crossentropy. It turns out that absorbing SEDD works better - even better than GPT-2 on WikiText-2 and PTB (though with a lot more training).
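As a concrete illustration of the absorbing case, here is a sketch of a per-token generator matrix under the convention used above (columns index the current token and sum to zero). The function name and the unit-rate normalization are my own choices; a noise schedule would scale this matrix over time.

```python
import jax.numpy as jnp

def absorbing_generator(vocab_size, mask_id):
    # Q[y, x] is the rate of jumping from token x to token y.
    # Every non-[MASK] token flows into [MASK] at unit rate; [MASK] is
    # absorbing, so its column is zero and each column sums to zero.
    q = jnp.zeros((vocab_size, vocab_size))
    q = q.at[mask_id, :].set(1.0)                                       # x -> [MASK]
    q = q.at[jnp.arange(vocab_size), jnp.arange(vocab_size)].add(-1.0)  # leave x
    q = q.at[:, mask_id].set(0.0)                                       # [MASK] stays
    return q
```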

It turns out that SE is equivalent to crossentropy in the absorbing case up to a constant. For a single token $x = \text{MASK}$, if the predicted ratio is $s_{\theta}(x)_y$, score entropy is

$$\sum_{y \neq x} w_{xy} \left( s_{\theta}(x)_y - \frac{p(y)}{p(x)} \log s_{\theta}(x)_y + K\left(\frac{p(y)}{p(x)} \right) \right)$$

$$= \sum_{y \neq \text{MASK}} w_{xy} \left( s_{\theta}(x)_y - \frac{\mathbb{I}_{y = \text{GT}}}{\exp(\sigma) - 1} \log s_{\theta}(x)_y \right) + C$$

...Where $\sigma = -\log(1-t)$ is the total magnitude of the noise. The probability that a token has been replaced by noise (masked) by time $t$ is $1 - \exp(-\sigma) = t$. Then, $\exp(\sigma) = \frac{1}{1-t}$ and $\frac{1}{\exp(\sigma) - 1} = \frac{1}{\frac{1 - (1-t)}{1-t}} = \frac{1-t}{t}$, which is exactly the ratio of the probability of the token being unmasked to that of it being masked. Note that this ratio is only nonzero when $y$ is the ground truth token; if it were not, the probability of transitioning to it would be zero. Continuing,

$$\sim (\exp(\sigma) - 1) \left( \sum_{y \neq \text{MASK}} s_{\theta}(x)_y \right) - \log s_{\theta}(x)_{\text{GT}}$$

This objective is much simpler. The objective of the network is then to maximize $\log s_{\theta}(x)_{\text{GT}}$ and minimize $\sum_{y \neq \text{MASK}} s_{\theta}(x)_y$. At this point, intuitively, this looks similar to crossentropy/negative log-likelihood, but with unnormalized logits: the network needs to allocate the score to the tokens it thinks are more likely and can't increase the total by too much. In SEDD, the prediction is actually the log-space score, so the objective becomes:

$$(\exp(\sigma) - 1) \exp(\text{logsumexp}(s_{\theta}(x))) - \log s_{\theta}(x)_{\text{GT}}$$

Given that the logsumexp is a single scalar, we can separate out the crossentropy subproblem and worry about optimizing it later, with the remaining term depending only on the logsumexp:

$$(\exp(\sigma) - 1) \exp(\text{logsumexp}(s_{\theta}(x))) - \text{logsumexp}(s_{\theta}(x))$$

$$+ \ \text{logsumexp}(s_{\theta}(x)) - \log s_{\theta}(x)_{\text{GT}}$$

...Which is exactly the crossentropy loss, plus a log-sum-exp term. The log-sum-exp is not optimized for during traditional NLL training, so this objective would be satisfied by a network minimizing crossentropy.

Notice that the score of the MASK token is not optimized for in this setup. If one inspects a trained SEDD model, the logit weights and biases corresponding to it are all zero, which makes sense given that they receive updates only from weight decay.

Still, SEDD was the first text diffusion model with a simple formulation and comparable performance to GPT-2. I initially based my implementation on it, the diffusion transformer being a one-to-one port of the SEDD code.

MDLM

MDLM (Sahoo et al., 2024) is the formulation of text diffusion used by most recent text diffusion models. It is a masked denoising language model with the same schedule as in SEDD, but with a simplified objective (the crossentropy/log-likelihood derived above).
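As a sketch of what that simplified objective looks like in practice, here is a weighted masked crossentropy in JAX. The $1/t$ weighting assumes a linear schedule ($\alpha_t = 1 - t$), and the function name, argument shapes, and clipping are illustrative assumptions rather than the MDLM reference code.

```python
import jax
import jax.numpy as jnp

def masked_diffusion_loss(logits, targets, is_masked, t):
    # logits:    (batch, seq, vocab) model predictions
    # targets:   (batch, seq) ground-truth token ids
    # is_masked: (batch, seq) 1.0 where the input token was [MASK]
    # t:         (batch,) diffusion time in (0, 1]
    logp = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(logp, targets[..., None], axis=-1)[..., 0]
    weight = 1.0 / jnp.clip(t, 1e-3, 1.0)        # ELBO weight for a linear schedule
    per_example = jnp.sum(nll * is_masked, axis=-1) * weight
    return per_example.mean()
```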

Once MDLM came out, I ported the diffusion process over to my incomplete codebase. It showed much better results on toy tasks, and I removed SEDD.

jif

The following is a description of the codebase.

Data & Setup

The dataset is TinyStories-1M. The tokenizer is GPT-2 trimmed to the first 32k tokens to decrease memory usage for logits and embeddings. The models are trained on context windows of 128 tokens with a batch size of 256. The masking ratio is uniformly distributed between 0 and 1 with MDLM's low-discrepancy sampling (using a linspace to set timesteps instead of an RNG).
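A minimal sketch of that low-discrepancy sampler, under the assumption that it boils down to an evenly spaced grid plus a shared random offset (the function name and the exact offsetting are mine, not necessarily what jif or MDLM do verbatim):

```python
import jax
import jax.numpy as jnp

def low_discrepancy_t(key, batch_size):
    # An evenly spaced grid with one shared random offset covers [0, 1)
    # far more evenly across a batch than i.i.d. uniform draws, reducing
    # the variance of the loss estimate.
    offset = jax.random.uniform(key, ())
    return (jnp.arange(batch_size) / batch_size + offset) % 1.0
```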

When sampling, we use 500 steps. If a network is not timestep-conditioned (and is instead conditioned on the number of unmasked tokens), it is possible to cache the logits from the previous timestep and reuse them if no token was unmasked, as described in MDLM.
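A rough sketch of a sampling loop with that caching behaviour. `unmask_step` is a hypothetical helper that chooses which masked positions to reveal; the point is only that the forward pass is skipped whenever the previous step changed nothing.

```python
import jax
import jax.numpy as jnp

def sample(model, tokens, mask_id, num_steps, key):
    cached_logits = None
    for _ in range(num_steps):
        if cached_logits is None:
            cached_logits = model(tokens)            # full forward pass
        key, subkey = jax.random.split(key)
        new_tokens = unmask_step(tokens, cached_logits, mask_id, subkey)
        if bool(jnp.any(new_tokens != tokens)):
            cached_logits = None                     # sequence changed, cache is stale
        tokens = new_tokens
    return tokens
```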

Model & Sharding

The model is a diffusion transformer as described in SEDD (source code). The model code is implemented in Penzai. The network is built declaratively, largely out of Penzai's pre-made components. We use Penzai's side-inputs system to pass the DiT conditioning along the module tree without adding it as an argument to each module. The Penzai implementation also means all axes in the model's parameters are named and can be sharded automatically. We shard along the "neurons", "kv_heads" and "vocabulary" named axes, which correspond to the inner dimension of the transformer's MLP, the attention heads, and the vocabulary. This means we use TP for the MLP and KV heads and let Jax figure out FSDP for the vocabulary.
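To illustrate what sharding one of those named axes means in plain jax.sharding terms (the repo itself goes through Penzai's named-axis machinery, so treat the mesh shape and the parameter shape here as hypothetical):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D "model" mesh over all available devices.
mesh = Mesh(np.array(jax.devices()), ("model",))

# Tensor parallelism for the MLP: shard a weight of shape (d_model, neurons)
# along its "neurons" axis, i.e. the MLP's inner dimension.
w_mlp = jnp.zeros((512, 2048))
w_mlp = jax.device_put(w_mlp, NamedSharding(mesh, P(None, "model")))
```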

Note: in jax 0.6, Penzai's default embedding layer was very slow. This is because it internally runs something like vmap(lambda i: x[i]), which makes Jax use scatter-add for the backward pass, which can't be optimized very well on TPUs. I replaced the vmap-of-lookup with a single indexing operation and reached around 15% MFU.
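A sketch of the two access patterns (toy shapes, not Penzai's actual layer):

```python
import jax
import jax.numpy as jnp

table = jnp.zeros((32_000, 512))                     # (vocab, d_model)
token_ids = jnp.zeros((256, 128), dtype=jnp.int32)   # (batch, seq)

# Slow: a vmap of a scalar lookup, roughly what the default layer did.
slow = jax.vmap(jax.vmap(lambda i: table[i]))(token_ids)

# Fast: one gather over the whole id array.
fast = table[token_ids]                              # == jnp.take(table, token_ids, axis=0)
```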

Training

We optimize the model with AdamW. We tried Muon and Schedule-Free Adam without improvement. We keep an exponential moving average of the weights with a decay of 0.995 like in MDLM. The model generates plausible text on a v4-8 within a minute of compilation.
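The EMA itself is a one-liner over the parameter pytree (a sketch; jif may implement it differently, for example through optax):

```python
import jax

def ema_update(ema_params, params, decay=0.995):
    # Exponential moving average of the weights, applied after each optimizer step.
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params)
```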

Future work

Normalization

This project was started with the goal of figuring out a MuP-normalized architecture for diffusion models. Specifically, I wanted to implement Modula, an architecture and optimizer combination that gives the network a Lipschitz constant of 1 at initialization and ensures that updates cannot increase the loss by more than a multiple of a specially tuned norm of the update. I ended up working on other things, but the repo contains a script for testing learning rate transfer, and Penzai is well-suited for implementing Modula on top of it.

Statefulness

Text diffusion has an inherent disadvantage compared to autoregression: it uses $O(N^3)$ compute in the number of tokens, having to do a forward pass through all tokens to sample a single token. This could turn into an advantage if it can process the context more deeply compared to autoregression, where each token's embedding has only previously sampled tokens as context. However, there is no way for any diffusion step to receive information from previous steps other than through sampled tokens, so this is difficult to do. To fix this, we could follow Bit Diffusion (Chen et al. 2022) in self-conditioning the model on detached-gradient predictions of the previous step. Still, I expect any such method to be tricky to implement and prone to instability, like any recurrent model.
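For reference, the self-conditioning trick amounts to something like the following, where the model taking the previous prediction as an extra input is a hypothetical API rather than anything in jif:

```python
import jax

def self_conditioned_logits(model, tokens, prev_logits, cond):
    # Bit Diffusion-style self-conditioning: feed the previous step's
    # prediction back in with gradients stopped, so the network can refine
    # its own earlier guess without backpropagating through time.
    return model(tokens, jax.lax.stop_gradient(prev_logits), cond)
```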

Training objective

Beyond Autoregression: Fast LLMs via Self-Distillation Through Time (Deschenaux & Gulcehre, 2024) introduces a simple method for distilling discrete diffusion models for sampling with fewer timesteps that scales to 1B-scale models. It is still an open question whether this scales up well and preserves the diversity of generations, but it should be possible to implement in jif.

PGM improves on the masking objective for diffusion model training by replacing it with grouped denoising - partitioning tokens into two groups, then allocating the first few layers to processing the unmasked tokens and the last few to processing only the masked tokens, with a cross-attention layer connecting them. This means the model can be trained to predict all tokens at once instead of 50% on average and can use a smaller attention matrix.

Something that I think could losslessly improve generation is optimizing the choice of tokens to sample and unmask at each step. The reason we can't unmask all tokens at once is that they are conditionally dependent on each other, so the unconditional prediction will be inaccurate. We could try to fix this with something like MTP adapted for encoder-only models, but a simpler way to get around this would be to sample tokens that are roughly conditionally independent. One way this could be done is by training a linear probe that embeds the final layer's activations into a lower-dimensional space, computing pairwise dot products between those low-dimensional embeddings, and training the dot products to predict the decrease in crossentropy from conditioning one token on another (by removing the mask on a group of masked tokens and computing the effect on all other masked tokens). If this is easy to predict (I expect that it is, as in the worst case the model can simply use the distance between the tokens as a proxy), we can sample the unmasked tokens by greedily unmasking the token that minimizes the interference with all other selected tokens. This would have the benefit of being a post-training optimization that doesn't interfere with the training objective.
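A sketch of the greedy selection step of this idea. Everything here is hypothetical: the probe, the convention that larger dot products mean stronger predicted interference, and the helper itself.

```python
import numpy as np

def greedy_unmask_order(probe_emb, masked_positions, k):
    # probe_emb: (seq, d) low-dimensional probe embeddings of the final-layer
    # activations; probe_emb[i] @ probe_emb[j] is assumed to predict how much
    # unmasking position i would change the prediction at position j.
    interference = probe_emb @ probe_emb.T
    chosen, candidates = [], list(masked_positions)
    while candidates and len(chosen) < k:
        # Pick the candidate with the least total predicted interference
        # with the positions already selected for unmasking.
        cost = [sum(interference[i, j] for j in chosen) for i in candidates]
        best = candidates[int(np.argmin(cost))]
        chosen.append(best)
        candidates.remove(best)
    return chosen
```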

[1] Gradient of log density; this is the noise up to a constant factor for Gaussian denoising diffusion.

[2] Actually computing the exact log likelihood requires finding the prior probability of the input noise and the log determinant of the Jacobian using the Hutchinson trace estimator. (Reference)

[3] Source: some research I probably won't publish.