MLPs are the most basic type of deep neural network. They are quick, easy, and cheap. CNNs are quick, easy, and cheap too, but I didn’t want to jump the gun and immediately use them. Today, though, we’re going to make the leap from basic to bespoke, using an architecture designed for the kind of structured data we’re getting from our MinAtar Breakout environment.
Why Now?
You might be wondering why I’m making the switch now, especially after setting up a whole investigative process in the last few posts. The truth is, my original plan didn’t quite work out.
I was waiting for the data to scream “Use a CNN!”, but that moment never came. And that led me to a new headspace. Some things are just known to be better for certain tasks; I don’t need to independently recognize the need for these things. Some examples of things that fall in this category:
- CNNs for image data (or ViTs)
- RNNs or Transformers for sequence data
- Dropout to prevent overfitting
- Using activations like ReLUs to keep the network stable
So, going forward, I’m not going to try to justify their use. We’ll likely do deep dives into each of these to build our intuition and understanding, but I’m going to make the ruling that we don’t need data to prove we should try them in our tasks.
Alright, let’s dive in. As we’ll soon see, what begins as a simple question about architecture will quickly turn into a deep dive into some surprising sources of training instability.
Where We Left Off
If you missed the last post, here is a quick rundown:
Last time, we did a deep dive into the entropy bonus and how to recognize, using some simple checks, when you need to adjust it.
Using this new knowledge, we fixed our Gymnax MinAtar Breakout agent’s entropy bonus and ended up with an agent that achieved an average episodic reward of 15.53 over 1024 rollouts, compared to the baseline of 28 listed on the Gymnax GitHub page.
We’re within striking range, so let’s see if today we can’t break that benchmark.
Flip That Switch!
Okay, switching from MLPs to CNNs is really easy in my codebase. It’s literally changing this:
class Policy(nn.Module):
    action_space: int

    @nn.compact
    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        x = obs.reshape((obs.shape[0], -1))
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_space)(x)
        return x


class Value(nn.Module):
    @nn.compact
    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        x = obs.reshape((obs.shape[0], -1))
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x
To this:
class Policy(nn.Module):
    action_space: int

    @nn.compact
    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        x = obs
        x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_space)(x)
        return x


class Value(nn.Module):
    @nn.compact
    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        x = obs
        x = nn.Conv(features=16, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x
That’s it! Now we’re using an architecture designed for our input data structure. So, thanks for reading this week folks. See you next time!
Actually, wait. There’s more to do. Did I pick the right kernel size? Hmm… not sure. Did I pick the right strides? The right padding? Oh, gosh! What about the number of CNN layers? And fully connected layers? Do we even need fully connected layers!? How does any of this affect the performance of our agent!?!
Okay, okay. Let’s systematically work through these problems and see if we can’t come up with some general principles for how to answer them in the future.
Can It Be Easy, Please?
I recently read OpenAI’s 2020 paper on scaling laws for neural language models (Kaplan et al.). They essentially predicted the power of LLMs two years before ChatGPT! Really interesting paper. But how is that related to the work we’re doing here with CNNs in a MinAtar Breakout environment?
Well, one of that paper’s findings was that the specific architectural choices of an LLM don’t matter all that much; the really important factor is the total number of parameters. So, without a better post to hitch to, let’s start our own exploration there. A hypothesis:
🤔Hypothesis
Total Parameters Hypothesis
If we maintain the same total number of parameters in our policy and value networks but modify the architecture, the agent will reach roughly the same average episodic reward after training, because what matters is the network’s capacity and the specific architectural decisions don’t change that capacity.
I named the hypothesis because we’ll actually be working with more than one in this post. Total parameters being the only important factor seems like a stretch to me, but we’ve got to start somewhere. So let’s get to testing our new hypothesis.
First, we’ll need a parameter budget. Flax, the library built on JAX that I use for my neural network definitions, provides an easy way to count the parameters of a given model via its built-in tabulate function. So, using that, let’s see how many parameters our new, arbitrarily selected CNN network has.
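Here’s roughly what that check looks like (a sketch; the observation shape and action count below are my assumptions about the MinAtar Breakout setup, and Policy is the CNN module defined above):

import jax
import jax.numpy as jnp

# Count parameters for the new CNN policy. The (1, 10, 10, 4) observation shape
# and the 3 actions are assumptions about the MinAtar Breakout environment.
policy = Policy(action_space=3)
dummy_obs = jnp.zeros((1, 10, 10, 4))

# Per-layer breakdown via Flax's built-in tabulate:
print(policy.tabulate(jax.random.PRNGKey(0), dummy_obs))

# Or a single total, by summing over the parameter leaves:
params = policy.init(jax.random.PRNGKey(0), dummy_obs)
print(sum(p.size for p in jax.tree_util.tree_leaves(params)))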
155,315
Okay. So that’s our parameter budget. Now, I’m tempted to come up with ~10 different architectures, all with roughly the same parameter count, run them all fully, and compare the final results. But a quicker, easier test is to come up with just one other architecture. If that one architecture performs differently, we can disprove the hypothesis. Fail fast and learn fast.
Sooo… ahem. “Gemini, how do I come up with multiple architectures with the same number of parameters? Hmm, there’s no systematic way to do it?” Okay… brute force? Give me a sec…
Okay, looking at the details of the first model (let’s call this the Base Model), we can see that the majority of the parameter count comes from the first dense layer: 131,200 parameters in total. (The three VALID 3×3 convs shrink the 10×10 grid to 4×4, so that layer sees a flattened 4×4×64 = 1,024-dimensional input, and 1,024 × 128 weights + 128 biases = 131,200.) So my goal is to increase the parameter count of the convolutional section and reduce the parameter count of the dense layers by the same amount.
Okay. It’s not pretty. But I got something together.
class Policy(nn.Module):
    action_space: int

    @nn.compact
    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        x = obs
        x = nn.Conv(features=8, kernel_size=(3, 3), strides=(1, 1), padding='VALID')(x)
        x = nn.relu(x)
        x = nn.Conv(features=8, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(4, 4), strides=(1, 1), padding='VALID')(x)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=16, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(4, 4), strides=(1, 1), padding='VALID')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.Conv(features=32, kernel_size=(4, 4), strides=(1, 1), padding='SAME')(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_space)(x)
        return x
This monstrosity, let’s call him Jumbo, has exactly 155,763 parameters compared to our Base Model’s 155,315. That’s only about a 0.3% difference between the two models, so this should be a good test. Let’s compare their performance!
| Model | Avg Episodic Reward | Min Episodic Reward | Max Episodic Reward |
|---|---|---|---|
| Base | 66.37 | 51 | 84 |
| Jumbo | 6.47 | 6 | 7 |
🔍Analysis
Okay, clearly there is a major difference in performance between these two models despite their nearly identical parameter counts. The Base Model, which puts most of its parameters in the fully connected layers, performs much better than our poor Jumbo network, which is mostly convolutions.
First off, 66.37! We just crushed the baseline, 28. Be right back. I’m gonna dance around my office.
Now the data is pointing pretty squarely towards disproving the hypothesis. But you might be thinking, “Wait a second, the new model is massively deeper than the Base Model. Maybe it’s just an issue of exploding or vanishing gradients that’s causing the much worse performance.” First of all, please call him by his proper name, Jumbo. And second, you might be right. So, let’s take a look at how the gradient is flowing through Jumbo.
I created a new visualization using Matplotlib that tracks the norm of the gradient (essentially how much the parameters are changing) for each layer in our model over the course of training.
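Here’s roughly how those per-layer norms get collected (a sketch, not the exact code from my repo; `grads` is the gradient pytree computed for a Flax model during an update):

import jax
import jax.numpy as jnp

# Compute the gradient norm for each parameter leaf (kernels and biases),
# keyed by its path in the pytree, e.g. "['params']['Conv_0']['kernel']".
def grad_norms(grads):
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(grads)
    return {
        jax.tree_util.keystr(path): float(jnp.linalg.norm(leaf))
        for path, leaf in leaves_with_paths
    }

Calling that once per update and stacking the results gives one time series per layer to plot with Matplotlib.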
Here it is for Jumbo:
[Figure: per-layer gradient norms over the course of training for Jumbo.]
🔍Analysis
If we were seeing exploding or vanishing gradients, we would expect big swings in the data, and if any of our layers were not receiving signal, we’d expect much more separation between the layers. Things look healthy here.
But let’s double-check. Since many of Jumbo’s conv layers produce identically shaped outputs, it’s easy to add skip connections, which are a known remedy for vanishing gradients. Let me add skip connections where I can and kick off one more training run. I’ll call this skip-connection version Jumbo+.
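For the curious, here’s the shape of the change (a sketch of the idea, not necessarily the exact block I dropped into Jumbo+): wherever a SAME-padded conv keeps the channel count, add the layer’s input back to its output.

import flax.linen as nn


class SkipConvBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        y = nn.Conv(features=self.features, kernel_size=(4, 4), padding='SAME')(x)
        y = nn.relu(y)
        # Residual add; only valid when x already has `features` channels.
        return y + x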
| Model | Avg Episodic Reward | Min Episodic Reward | Max Episodic Reward |
|---|---|---|---|
| Jumbo+ | 6.46 | 6 | 7 |
🔍Analysis
Skip connections appear to make no difference in final reward performance for Jumbo.
Well, it seems those results quickly disprove my hypothesis! But let’s see if we can’t find some useful information about why it failed so we can form a better one.
First, let’s look at their losses and average episodic rewards over time to make sure they’re stable. I threw together a little Matplotlib visualization of the agent’s actor loss, critic loss, and average reward over the whole training process.
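It’s nothing fancy, roughly this (assuming `actor_losses`, `critic_losses`, and `avg_rewards` are lists logged once per update):

import matplotlib.pyplot as plt

# Three stacked panels sharing an x-axis of update steps.
fig, axes = plt.subplots(3, 1, sharex=True, figsize=(8, 9))
for ax, series, label in zip(
    axes,
    [actor_losses, critic_losses, avg_rewards],
    ["Actor loss", "Critic loss", "Avg episodic reward"],
):
    ax.plot(series)
    ax.set_ylabel(label)
axes[-1].set_xlabel("Update")
plt.tight_layout()
plt.show()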
First, the baseline:
[Figure: actor loss, critic loss, and average episodic reward over training for the Base Model.]
🔍Analysis
Whoa. Look at how volatile the losses and rewards get after about 2000 training steps. I definitely didn’t expect this based on just the final reward.
It seems like the agent is pretty stable initially but then discovers some new policy. Once in this new policy regime, the learning rate seems too high and the loss starts jumping all over.
Well, my guess would be that this loss instability is an issue with the scale of rewards as the policy gets better. So, we have another hypothesis:
🤔Hypothesis
Reward Scaling Hypothesis
If we normalize the reward signal during training, the agent will get better average episodic rewards after training, because normalization will smooth out the loss fluctuations that otherwise appear as rewards grow, providing a more stable learning signal.
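To make “normalize the reward signal” concrete, here’s one common approach I have in mind (just a sketch, not necessarily what we’ll end up implementing): scale each reward by a running estimate of the standard deviation of the discounted return.

# Purely illustrative; a real version would live inside the JAX training loop.
class RewardScaler:
    def __init__(self, gamma=0.99, eps=1e-8):
        self.gamma, self.eps = gamma, eps
        self.ret = 0.0          # running discounted return
        self.ret_sq_mean = 0.0  # running mean of the squared return
        self.count = 0

    def scale(self, reward, done):
        self.ret = self.gamma * self.ret * (1.0 - float(done)) + reward
        self.count += 1
        self.ret_sq_mean += (self.ret ** 2 - self.ret_sq_mean) / self.count
        return reward / (self.ret_sq_mean ** 0.5 + self.eps)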
Great, another hypothesis! Before we dig into that, though, let’s see if we’re having the same problem with Jumbo. Perhaps our Total Parameters Hypothesis isn’t disproven yet!
Here’s Jumbo:
[Figure: actor loss, critic loss, and average episodic reward over training for Jumbo.]
🔍Analysis
Well, we might be seeing the same thing here. The critic loss in particular is stable for a time and then becomes very unstable. This time it happens much sooner, at around 500 updates.
Okay, both models will clearly benefit from some better loss stability. Let’s put a pin in our Total Parameters Hypothesis while we resolve our training stability issues.
An Unfair Advantage
This hypothesis, that our loss instability is tied to the scale of rewards, gives us a clear direction for our search. Let me see what I can find…
Okay, so this led me down a rabbit hole of digging into community best practices for PPO stability. It turns out this is a well-known issue, and top practitioners have a standard solution. While the original PPO paper uses Generalized Advantage Estimation (GAE) directly, many influential implementations, like those in OpenAI Baselines, add one crucial step: they batch-normalize the advantages (subtract the batch mean, divide by the batch standard deviation) before using them in the loss calculation.
This is exactly the kind of technique that would test our Reward Scaling Hypothesis. It’s a direct method for keeping the gradient signal stable, even as rewards go up. Let’s add it to our code and see if the results support our hypothesis…
Adding the normalization is as simple as adding this line of code to our loss calculation:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
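For context, here’s roughly where that line sits in a standard clipped PPO actor loss (a sketch with illustrative names like `clip_eps` and `new_log_probs`, not necessarily the exact ones in my codebase):

import jax.numpy as jnp

# Normalize advantages across the batch, then apply the usual PPO clipped objective.
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
ratio = jnp.exp(new_log_probs - old_log_probs)
clipped_ratio = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
actor_loss = -jnp.mean(jnp.minimum(ratio * advantages, clipped_ratio * advantages))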
Okay, let’s run our base model and Jumbo with this new batch-normalized advantage and see if anything changes.
Uhm, wow. Check this out:
| Model | Avg Episodic Reward | Min Episodic Reward | Max Episodic Reward |
|---|---|---|---|
| Base | 84 | 84 | 84 |
| Jumbo | 10.38 | 7 | 14 |
🔍Analysis
Obviously, we’ve got a huge improvement again for our base model. I’d guess we still have a little headroom but we’re getting better and better.
Jumbo is still struggling comparatively but has also improved. So batch-normalized advantage is clearly beneficial here.
💡Takeaway
Batch-normalized advantages can provide a major lift in performance for PPO agents because they keep the scale of the learning signal stable even as rewards grow. This will become one of my “some things are just known to be better for certain tasks” algorithmic decisions that I was talking about earlier.
Another huge jump, and data that supports our hypothesis! We’re well above the baseline listed on the Gymnax GitHub page: our 84 vs their 28. I want to take a moment to revel in this. Before starting this blog, I would have gotten the initial 3.89 average episodic reward I called out in my first Breakout post, started fiddling with hyperparameters, and maybe improved a bit, but I likely wouldn’t have gotten beyond 28. This huge improvement is a major reinforcement that systematically understanding, testing, and visualizing before making changes, rather than blindly guessing, is the right route, even with the extra time it takes.
And speaking of verifying… I was so excited by this result that I immediately started working on the next steps. But as I kept digging into the new loss curves, something still felt wrong. It’s that feeling you get when the numbers look good, but the underlying behavior still seems unstable. That nagging feeling forced me to go back and audit my advantage calculation one more time.
And that’s when I found it.
A Confession
I just spent a long time (days) running tests after making this batch-normalized advantage update, trying to resolve some issues with the critic’s loss signal that seemed to appear after the change. I even had another post and a half written about the whole process. But it turns out it was all for naught, because I had a major bug in my advantage calculation.
The bug is pretty subtle and won’t be obvious from looking at the code, so I’ll try to explain it. Essentially, when calculating the discounted advantages, I’d been using the wrong observation.
The observation used to bootstrap the return at the end of the rollout, which should have been the final observation, was actually the first observation. So the whole backward calculation was seeded with an incorrect value estimate! This means the entire advantage signal my agent was learning from was based on a state it wasn’t even in. The ‘good’ performance wasn’t just a fluke; it was based on nonsensical data. Shooooot!
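To make the bug concrete, here’s a rough sketch of what the advantage calculation should have been doing (illustrative names, written as a plain Python loop for clarity rather than however the real code structures it):

# GAE-style advantages; the key point is that `last_obs` must be the observation
# AFTER the final step of the rollout, not the rollout's first observation.
last_value = value_net.apply(value_params, last_obs)
gae = 0.0
advantages = [0.0] * rollout_len
for t in reversed(range(rollout_len)):
    next_value = last_value if t == rollout_len - 1 else values[t + 1]
    not_done = 1.0 - dones[t]
    delta = rewards[t] + gamma * next_value * not_done - values[t]
    gae = delta + gamma * gae_lambda * not_done * gae
    advantages[t] = gae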
Yeah, I think it caused a whole bunch of issues. And I thought I was being really clever, trying to fix the loss signal issues by adding clipping on the advantage term of the PPO loss and introducing learning rate annealing. None of it was working, and I had a bunch of takeaways written about how those techniques aren’t well suited to my environment.
Turns out it was just a logical bug in my code. Dang!
So rather than have you all read through all that, which, in light of the bug, is mostly meaningless, I thought I’d cut it all and come back where it makes sense to fix it.
Squash That Bug!
Okay. I’ve fixed this dumb bug. I’m now using the correct observations for my advantage calculations. Let’s see if that improves our performance even more. Okay, hmm… That bug must have been compensating for some deeper issues with our critic.
| Model | Avg Episodic Reward | Min Episodic Reward | Max Episodic Reward |
|---|---|---|---|
| Base | 9.41 | 7 | 12 |
| Jumbo | 7.52 | 7 | 8 |
🔍Analysis
Okay, so the bad news is that our results are much worse after fixing the bug. Our scores have dropped precipitously, back down well below the baseline of 28 and far below our high of 84.
That is really disappointing. But the correct action here is not to revert the code, of course. The 84 score was real, but it was achieved with a broken learning signal. This means the bug was unintentionally acting as a strange form of regularization that, by pure chance, worked incredibly well.
So the real mystery isn’t why our agent is failing now, but what that bug was accidentally doing right that we need to replicate in a principled way.
Wrap Up
What a roller coaster. We started with a simple plan to test CNNs, skyrocketed to a score of 84 with a community best practice, and then came crashing back down after discovering a critical bug.
The takeaway is clear and crucial:
💡Takeaway
Be vigilant for bugs, not just when your agent is struggling, but especially when it seems to be performing unexpectedly well.
It’s a tough lesson, but a vital one. Now, with correct code, the path forward is clear. We need to become bug detectives. In the next post, we’ll perform a forensic analysis of the old, buggy agent to understand the bizarre signals it was learning from. By understanding the why behind the fluke, we can hopefully find a principled way to get back to that high score, and hopefully beyond!
Code
You can find the complete, finalized code for this post on the v1.4 Release on GitHub.