The hardest part of Reinforcement Learning isn’t the math or the complex theory or the ML frameworks. It’s what comes after the code finally runs. For me, that’s when the real challenge begins: tuning the agent.
In the past, my process was mostly guesswork: endless fiddling with hyperparameters until, by sheer luck, something worked. More often, I’d just give up.
Today, I’m ditching that approach. I’m starting by documenting my journey to find a principled, systematic way to turn a new agent into a top performer.
Coding the Agent
Today, my goal is to code a PPO Actor-Critic model in JAX from scratch to run on the Gymnax MinAtar Breakout environment. Specifically, I would like to learn JAX and start a library of stable baselines for future experiments.
For those unfamiliar, let me quickly define all this stuff:
- PPO
- What: A policy optimization algorithm that is pretty standard fare in RL these days.
- Why: It is extremely stable and simple to implement.
- Additional Reading: Nothing beats reading the paper.
- Actor-Critic
- What: A paradigm in RL in which you have a policy network making decisions and a value network producing value estimates of states.
- Why: The pair together give you the benefits of both policy optimization and value estimation.
- Additional Reading: The Reinforcement Learning Bible.
- JAX
- What: A Python library that provides the fast numerical array operations we care about in ML.
- Why: JAX does these operations very efficiently and can run them in parallel on GPUs, meaning once you’ve got the code up and running, iterating can be very fast. This is great! When you iterate quickly, you can fail fast and learn fast.
- Additional Reading: JAX Documentation
- Gymnax MinAtar Breakout
- What: Gymnax is a standard set of RL environments implemented in JAX, meaning that, just like our model code, they run very efficiently and in parallel on GPUs.
- Why: You guessed it! We can iterate very fast, fail fast, and learn fast.
- Additional Reading: The Gymnax GitHub Page
I’ve only ever vibe-coded JAX before, and I haven’t done functional programming in 10 years, so I’m expecting a steep learning curve. This is a far cry from my home base, TensorFlow, which doesn’t impose any hard structural requirements on your code; you just call TF functions.
Okay, let’s start coding! Things seem to be going well. To get a feel for things, I’m ignoring JIT-ting for now. JIT (just-in-time compilation) is an option in JAX that converts your raw Python into compiled XLA (Accelerated Linear Algebra), which makes your code run extremely fast but comes with some steep requirements: every operation must be pure, with no side effects, and any state must be passed in as parameters and returned at the end of the function.
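To make that concrete, here’s a minimal sketch of the functional pattern JIT wants. The function and names are purely illustrative, not part of my agent:

import jax
import jax.numpy as jnp

@jax.jit
def update_running_mean(mean: jnp.ndarray, count: jnp.ndarray, x: jnp.ndarray):
    # Pure function: no globals, no in-place mutation. Everything it needs comes
    # in as arguments, and the "updated state" goes out as return values.
    new_count = count + 1.0
    new_mean = mean + (x - mean) / new_count
    return new_mean, new_count

mean, count = jnp.zeros(3), jnp.array(0.0)
mean, count = update_running_mean(mean, count, jnp.array([1.0, 2.0, 3.0]))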
Okay, I think things are working. But it’s very slow. Maybe slower than TF. Let’s start using JIT, the real power of JAX.
Okay, I’m getting the hang of this. But at one point I put a Python for-loop inside a JIT’ed function, which made my computer run like molasses for about 30 minutes while the XLA compiler unrolled the whole for-loop rollout into one giant computation graph! Turns out that is a big no-no, and it’s exactly why jax.vmap and jax.lax.scan exist. With that fixed, let’s see if I can’t finish this.
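For the record, here’s a minimal sketch of what the fix looks like: a rollout driven by jax.lax.scan instead of a Python for-loop. It uses a random policy as a stand-in for the actor and the reset/step signatures from the Gymnax README, so treat it as an illustration rather than my exact training code:

import jax
import gymnax

NUM_STEPS = 128
env, env_params = gymnax.make("Breakout-MinAtar")

@jax.jit
def rollout(rng):
    rng, reset_key = jax.random.split(rng)
    obs, env_state = env.reset(reset_key, env_params)

    # One environment step as a pure function: carry in -> (carry out, per-step outputs).
    def step_fn(carry, _):
        rng, obs, env_state = carry
        rng, act_key, step_key = jax.random.split(rng, 3)
        action = env.action_space(env_params).sample(act_key)  # random-policy stand-in
        next_obs, env_state, reward, done, _ = env.step(step_key, env_state, action, env_params)
        return (rng, next_obs, env_state), (obs, action, reward, done)

    # scan traces step_fn once and reuses it for all NUM_STEPS iterations,
    # instead of unrolling the entire rollout into one enormous XLA graph.
    _, trajectory = jax.lax.scan(step_fn, (rng, obs, env_state), None, length=NUM_STEPS)
    return trajectory

observations, actions, rewards, dones = rollout(jax.random.PRNGKey(0))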
Okay, I have gotten the Gymnax Breakout agent coded up and running in JAX! Turns out, once you get used to the functional style of thinking, it is pretty straightforward. We’ll save running the full training process until a bit later.
First, there’s one interesting tidbit that struck me while implementing the PPO loss that I think is worth exploring deeper.
A Mystery
Whenever I’ve looked up implementations of PPO in the past, I have always been puzzled because the code doesn’t exactly match the paper’s definition. Here’s my final PPO loss function as an example:
import jax
import jax.numpy as jnp

_CLIP_EPSILON = 0.2  # clip range epsilon (0.2 is the PPO paper's default)

def policy_loss(logits: jnp.ndarray, old_logits: jnp.ndarray,
                advantages: jnp.ndarray, actions: jnp.ndarray) -> float:
    # Log-probabilities of the taken actions under the current and old policies.
    one_hot_actions = jax.nn.one_hot(actions, num_classes=logits.shape[-1])
    log_probs = jax.nn.log_softmax(logits)
    log_probs = jnp.sum(log_probs * one_hot_actions, axis=-1)
    log_probs_old = jax.nn.log_softmax(old_logits)
    log_probs_old = jnp.sum(log_probs_old * one_hot_actions, axis=-1)
    # Probability ratio r_t = pi(a|s) / pi_old(a|s), computed via log-probabilities.
    ratio = jnp.exp(log_probs - log_probs_old)
    # Clipped surrogate objective, negated because optimizers minimize.
    loss = jnp.minimum(ratio * advantages,
                       jnp.clip(ratio, 1 - _CLIP_EPSILON, 1 + _CLIP_EPSILON) * advantages)
    loss = -jnp.mean(loss)
    return loss
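For a quick sanity check, the loss can be called on a dummy batch like this (the batch size and action count are just illustrative):

import jax
import jax.numpy as jnp

batch, num_actions = 32, 4  # illustrative sizes, not Breakout's real action count
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
logits = jax.random.normal(k1, (batch, num_actions))
old_logits = logits + 0.01 * jax.random.normal(k2, (batch, num_actions))
advantages = jax.random.normal(k3, (batch,))
actions = jax.random.randint(k4, (batch,), 0, num_actions)
print(policy_loss(logits, old_logits, advantages, actions))  # a single scalar loss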
And the PPO clipped loss definition from the paper:

$$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\hat{A}_t\right)\right], \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$
Do you see the difference? Something with the ratio?
Here’s the difference: for the ratio $r_t(\theta)$, I’m using a log_softmax function, which returns log probabilities instead of raw probabilities, and then I subtract the two log probabilities. In the paper’s equation, they simply divide the raw, non-log probabilities. Mathematically, the two are equivalent: $\exp(\log \pi_\theta - \log \pi_{\theta_{\text{old}}}) = \pi_\theta / \pi_{\theta_{\text{old}}}$. So why compute the log first? I’ve always blindly done it, assuming that if everyone else does it, it must be right. But this blog is about becoming better at RL, so this time, we’ll dig deeper.
The Log-Sum-Exp Trick
Okay, so after some digging, I’ve found the reason. The claim is that working in log space is more numerically stable than working with the raw probabilities. Let’s look at the softmax definition that we use to get our probabilities from logits:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$
Notice how much we rely on exponentials. It turns out that exponentials of large negative numbers quickly become too small for computers to represent in floating point.
$e^{-800}$ is an astronomically small number. It’s so small that a standard computer can’t store it and rounds it down to exactly 0.0. This is underflow.
But -800 is a perfectly reasonable logit to get out of a neural network. Hmm… that seems bad.
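You can confirm this in one line (JAX arrays default to float32, where the cutoff hits even sooner than in float64):

import jax.numpy as jnp

print(jnp.exp(-800.0))  # 0.0 -- the true, tiny value has underflowed to exactly zero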
Here are a couple of logit-vector examples to illustrate the point. Let’s calculate the probability of the first action for each of them using the softmax equation above.
Logits from model: [-750, 0]. Exact softmax for the first action: $\frac{e^{-750}}{e^{-750} + e^{0}} \approx 10^{-326}$. With underflow: $\frac{0}{0 + 1} = 0$.
Logits from model: [-750, -800]. Exact softmax for the first action: $\frac{e^{-750}}{e^{-750} + e^{-800}} \approx 1$. With underflow: $\frac{0}{0 + 0} =$ nan in Python, because both exponentials underflow.
Uh, oh! In a normal softmax, some not-so-large negative numbers are already causing nan! How can logs save us from this?
The log-space version is much safer because it uses a trick to keep the numbers we exponentiate from getting out of hand, so we never hit this underflow. Check out the log softmax:

$$\log\text{softmax}(x_i) = (x_i - m) - \log\sum_j e^{x_j - m}, \qquad m = \max_j x_j$$
Notice how we’re subtracting the maximum logit value from all logits before exponentiating. This is called the log-sum-exp trick and it completely fixes our underflow issue!
Let’s look at our nan example again:
Logits from model: [-750, 0]. Log softmax for the first action: $(-750 - 0) - \log(e^{-750} + e^{0}) \approx -750$. The $e^{-750}$ still underflows to 0, but now it is added to 1, which is harmless.
Logits from model: [-750, -800]. Log softmax for the first action: $(-750 - (-750)) - \log(e^{0} + e^{-50}) \approx 0$. No nan!
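Here’s the same check in code. (naive_softmax is my own throwaway helper for the direct formula; jax.nn.log_softmax is the built-in, numerically stable version.)

import jax
import jax.numpy as jnp

def naive_softmax(logits):
    # Direct translation of the softmax formula -- no max subtraction.
    exps = jnp.exp(logits)
    return exps / jnp.sum(exps)

logits = jnp.array([-750.0, -800.0])
print(naive_softmax(logits))       # [nan, nan] -- both exponentials underflow, giving 0 / 0
print(jax.nn.log_softmax(logits))  # roughly [0., -50.] -- finite log-probabilities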
We got rid of our nan! And even better, because we exponentiate the difference of log-probabilities later anyway (ratio = jnp.exp(log_probs - log_probs_old)), we never have to take the log of a probability that may have underflowed to 0 either. Should we visualize what we’re talking about here? I think so.
I threw together a little Python script (with the help of Gemini) that sweeps the logits passed to our PPO loss toward large negative values while keeping a fixed ratio between them, then computes the loss both in log space and in the original form.
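Here’s a condensed sketch of the idea; the clip value, the 0.1 logit gap, and the sweep range are illustrative choices, and my real script also handles the plotting:

import jax
import jax.numpy as jnp

_CLIP_EPSILON = 0.2  # illustrative clip value

def ppo_clip_loss(ratio, advantage):
    return -jnp.minimum(ratio * advantage,
                        jnp.clip(ratio, 1 - _CLIP_EPSILON, 1 + _CLIP_EPSILON) * advantage)

def losses_at(offset, advantage=1.0):
    # Two-action logits pushed toward large negative values. The gap between the
    # new and old logits is fixed at 0.1, so the *true* probability ratio barely changes.
    logits = jnp.array([offset, 0.0])
    old_logits = jnp.array([offset + 0.1, 0.0])
    # Log-space path: subtract log-probabilities, then exponentiate (what policy_loss does).
    log_ratio = jax.nn.log_softmax(logits)[0] - jax.nn.log_softmax(old_logits)[0]
    loss_log = ppo_clip_loss(jnp.exp(log_ratio), advantage)
    # Direct path: divide the raw softmax probabilities, as in the paper's equation.
    direct_ratio = jax.nn.softmax(logits)[0] / jax.nn.softmax(old_logits)[0]
    loss_direct = ppo_clip_loss(direct_ratio, advantage)
    return loss_log, loss_direct

offsets = jnp.linspace(0.0, -400.0, 201)
loss_log, loss_direct = jax.vmap(losses_at)(offsets)
print(loss_log[-1], loss_direct[-1])  # the log path stays finite, the direct path ends in nan

Let’s see what the full sweep looks like when plotted: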

🔍Analysis
We can clearly see that as the logits become large-magnitude negative numbers, the log version remains stable, while the direct version first becomes incorrect and then produces nans. In a real ML situation, nan losses will destroy your whole training run!
💡Takeaway
When you see a sum of exponents (or a softmax) that can have large magnitude numbers, think “Log-Sum-Exp Trick.”
This raises another question in my mind, though. It is clear that at extreme magnitudes the log_softmax version is more stable. But what about the “safer” magnitudes? Which of the two methods is more precise in the safe zone? Are we trading large-scale stability for small-scale precision errors? Well, we can visualize this too!

🔍Analysis
First of all, the obvious takeaway is that at this level of zoom, neither method is perfectly precise. But keep in mind that the differences we’re looking at here are minuscule.
Another thing to note: If you look closely at the direct method, the error isn’t constant; it’s noisy. This is likely due to accumulated floating-point rounding errors.
In contrast, the log-space version’s error is flat and consistent. This suggests a more systematic precision error from the underlying log and exp functions. The error is predictable, not random.
💡Takeaway
We’re looking at a trade-off between two types of tiny errors. The direct method introduces unpredictable, noisy rounding errors, while the log-space method has a predictable, systematic error. For scientific computing and machine learning, predictable behavior is generally preferred. This makes the log-space method the clear winner on a theoretical level.
Does this make a difference in practice? Let’s put a pin in that. Once we have some baseline agents up and running, we can do some empirical studies to find out.
The Agent
Okay so, the agent is coded. We’re using the log-sum-exp trick to make sure our loss is stable, even with large logits. Let’s train this puppy for 1 million steps and see how it does!
Average Episode Reward over 256 runs: 3.98
Okay. Hmm, is that good? It seems low. Let’s check the Gymnax GitHub page; they list baseline PPO agent rewards for all of their environments, and for MinAtar Breakout the baseline is 28.
So we’re looking at 3.98 out of 28.
If this were an exam, we’d be getting an F. Well, dang… no surprise there, I guess. It’s extremely unlikely I would get all the hyperparameters just right to get a great agent by happenstance.
This is the point where I would normally start blindly fiddling with hyperparameter knobs. This time, I will resist the urge and find a more methodical approach. But I’ll have to save that for next time, when I dive head-first into visualizing everything about my agent’s behavior to uncover how to improve its performance.
Code
You can find the complete, finalized code for this post on the V1.1 Release on GitHub.