A visualization of a fact about probability

  • 3rd Jan 2025
  • 4 min read
  • Tags: math, sketch

Note: I no longer understand why I thought it would be a good idea to write this post with just this content, but here it is.

Suppose you have a variable which may take one of $k$ discrete values. You can observe it repeatedly, and each observation has some chance of ruling out each value; the probability that a single observation fails to rule out value $i$ is the $i$-th entry of a vector $v \in \mathbb{R}^k$. Once only one option for the variable remains, the sampling process stops.[^1] You also have some prior expectation of which value the variable takes, a vector $p \in \Delta^k$. What do you expect the graph of the probability of the variable having each value to look like over time?

It is obvious that the value with the highest prior probability will have the highest probability at the beginning. My intuitive understanding of the problem tells me that the value least likely to be ruled out by an observation will rise to the top the fastest and stay there. Is this understanding correct?

Below is a computation of one instance of this problem in NumPy.

from matplotlib import pyplot as plt
import numpy as np

# Prior over the four possible values, normalized to sum to one.
prior = np.asarray([10, 2, 4, 3])
prior = prior / prior.sum()
# Per-observation probability that each value is *not* ruled out.
update = np.asarray([0.9, 0.98, 0.97, 0.975])
# Time steps (number of observations so far).
t = np.arange(200)
# Unnormalized posterior after t observations: prior_i * update_i ** t.
unnorm = prior * np.power(update[None, :], t[:, None])
# Normalize across values at every time step.
prob = unnorm / unnorm.sum(axis=-1, keepdims=True)
for i, line in enumerate(prob.T):
    plt.plot(line, label=f"Line {i+1}")
plt.xlabel("Time")
plt.ylabel("Probability")
plt.legend()
plt.show()

...This is a pretty weird shape! Each of the possible values of the variable gets a single peak on this plot, and several of them take a turn at the top before the slowest-decaying one wins. This is a counterexample to my intuition, even though there are cases where it makes correct predictions.
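As a quick sanity check of the single-peak observation, here is a small snippet that finds where each curve in `prob` attains its maximum:

# Time step at which each value's normalized probability peaks.
peak_steps = prob.argmax(axis=0)
for i, step in enumerate(peak_steps):
    print(f"Line {i+1} peaks at step {step} with probability {prob[step, i]:.3f}")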

The explanation for this is very simple. Let's model the updating process in the space of log probabilities. A step of updating consists of adding $\log v$, the log probability of each value surviving an observation, and then applying log softmax to renormalize, so that the probabilities of the values sum to one when exponentiated.
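Here is a minimal sketch of one such update step in log space, as a small helper I'll call `log_softmax_step`:

def log_softmax_step(logprob, log_update):
    # Add each value's per-step log survival probability, then renormalize
    # with log softmax so the probabilities sum to one when exponentiated.
    unnorm = logprob + log_update
    return unnorm - np.log(np.sum(np.exp(unnorm)))

logprob = np.log(prior)
for _ in range(3):
    logprob = log_softmax_step(logprob, np.log(update))
print(np.exp(logprob))  # should agree with prob[3] from the first code block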

Log softmax works by subtracting the log-sum-exp of the vector from each of its entries. Let's denote it as $l(x) = x - \log \sum_{i=1}^k \exp x_i$ for $x \in \mathbb{R}^k$ and prove that $l(l(x + y) + z) = l(x + y + z)$ - in other words, that renormalizing after every addition gives the same result as renormalizing once at the end. Note that $l(x + y) = (x + y) + c$ for the scalar $c = -\log \sum_{i=1}^k \exp (x_i + y_i)$ (added to every component), so $l(l(x + y) + z) = l((x + y + z) + c)$, and it suffices to prove that $l(a + c) = l(a)$ for any $a \in \mathbb{R}^k$ and $c \in \mathbb{R}$. We can do this by direct computation: $l(a + c) = (a + c) - \log \sum_{i=1}^k \exp (a_i + c) = (a + c) - \log \left( \exp(c) \sum_{i=1}^k \exp a_i \right) = a + c - c - \log \sum_{i=1}^k \exp a_i = l(a)$. $\blacksquare$, for good measure.
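A quick numerical check of this identity with random vectors, using a small `log_softmax` helper:

def log_softmax(x):
    # l(x) = x - logsumexp(x), matching the definition of l above.
    return x - np.log(np.sum(np.exp(x)))

rng = np.random.default_rng(0)
x, y, z = rng.normal(size=(3, 5))
# Applying l after every addition agrees with applying it once at the end.
print(np.allclose(log_softmax(log_softmax(x + y) + z), log_softmax(x + y + z)))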

So, after $t$ steps of updating our log probability is $l(\log p + t \log v)$. If we ignore the $l$, the expression is linear in $t$. Moreover, we know that $\log v_i < 0$ for every $1 \le i \le k$, since each $v_i < 1$. This means we have a set of downward-sloping lines in unnormalized log space. Framed this way, it seems obvious that some of them can become the topmost line for a stretch of time before a line with a shallower slope overtakes them:

# Unnormalized log probabilities: log(prior_i) + t * log(update_i),
# a straight line in t for each value.
unnorm_logprob = np.log(prior) + np.log(update[None, :]) * t[:, None]
for i, line in enumerate(unnorm_logprob.T):
    plt.plot(line[:100], label=f"Line {i+1}")  # zoom in on the first 100 steps
plt.xlabel("Time")
plt.ylabel("Unnormalized log probability")
plt.legend()
plt.show()

...It is kind of hard to see the top log probability switch in this example, even when zoomed in to the first half of the observations. Still, it conveys the basic idea: it is much easier to predict when the lines will overtake each other on a log-scale plot without normalization. With normalization it looks much less obvious:

unnorm_logprob = np.log(prior) + np.log(update[None, :]) * t[:, None]
# Renormalize with log softmax: subtract the log-sum-exp across values.
# (In practice scipy.special.log_softmax would be the numerically stable choice.)
logprob = unnorm_logprob - np.log(np.sum(np.exp(unnorm_logprob), axis=-1, keepdims=True))
for i, line in enumerate(logprob.T):
    plt.plot(line[:100], label=f"Line {i+1}")
plt.xlabel("Time")
plt.ylabel("Normalized log probability")
plt.legend()
plt.show()

The fact that a simple formula over these lines determines the points where each possibility becomes dominant reminds me of a Legendre transformation. Perhaps Hough transforms could be adapted to recover the likelihoods exactly from noisy data. I may extend this blog post if I get to that.
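For reference, the simple formula in question: in unnormalized log space, line $j$ overtakes line $i$ at $t = \frac{\log p_i - \log p_j}{\log v_j - \log v_i}$. A short sketch that evaluates it for every pair of lines from the first code block:

# Pairwise crossing times of the unnormalized log-probability lines:
# line j overtakes line i at t = (log p_i - log p_j) / (log v_j - log v_i).
log_p, log_v = np.log(prior), np.log(update)
for i in range(len(prior)):
    for j in range(len(prior)):
        if log_v[j] > log_v[i]:  # line j decays more slowly than line i
            t_cross = (log_p[i] - log_p[j]) / (log_v[j] - log_v[i])
            print(f"Line {j+1} overtakes line {i+1} at t ≈ {t_cross:.1f}")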