PufferLib is a library for sane and simple reinforcement learning. Our key features are:
Clean PuffeRL: Our training demo is a turbocharged version of CleanRL's PPO + LSTM with severalfold performance improvements and major quality of life upgrades. Train at 1M+ steps/second and easily run hyperparameter sweeps powered by CARBS on dozens of environments.
Vectorization: Synchronous and asynchronous parallel simulation at millions of steps per second. Our multiprocessing backend has native multiagent support and can be over 10x faster than Gymnasium's for some environments.
Emulation: Compatibility tools that make working with complex environments a breeze. Your environment will still be in Gymnasium/PZ format, but it will use a subset of the API that is easier for most libraries to deal with. Even if you use nothing else from PufferLib, this layer is worth your time.
Ocean: Our first-party suite of ultra performant environments written in C. They use our native PufferEnv API and each run 1M+ steps/second per CPU core. New environment contributions welcome!
This tutorial contains everything you need to start doing RL 100x faster than most of the field. We also highly encourage you to read the PufferLib source. It's not like other RL libraries: PufferLib does what it does in a few thousand lines of very simple code. It is also much easier to get help when you're stuck. Our Discord community is active and helpful.
Installation
CONDA USERS: Be warned, Conda's C compiler is fundamentally broken. We have hacked around it so Ocean envs will at least run, but most will be several times slower than if you install without Conda.
PufferTank (Recommended)
PufferTank is a prebuilt GPU Docker image with PufferLib and dependencies for all environments in the registry, including some that are slow and tricky to install. If you have not used containers before and just want everything to work, clone the repository and open it in VSCode. You will need to install the Dev Container plugin as well as Docker Desktop. VSCode will then detect the settings in .devcontainer and set up the container for you. Neovim (btw) is also included.
git clone https://github.com/pufferai/puffertank
cd puffertank
bash docker.sh test # Run image. Downloads ~12GB on the first run.
Pip Installation (core only, no training demo)
pip install pufferlib
pip install pufferlib[cleanrl,atari] # Extras for this tutorial
Source Installation
git clone https://github.com/pufferai/pufferlib
cd pufferlib
pip install -e .
pip install -e .[cleanrl,atari] # Extras for this tutorial
Training Demo
PufferLib includes a few different training scripts. You will notice that they are in the base repository, not in the pip package. This is because they are based on CleanRL, which is not meant to be packaged. Rather, it is a white-box library that is meant to be copied and edited to suit your needs. Integrating PufferLib only requires changing a few lines:
python cleanrl_ppo_atari.py
This is already much faster than the original CleanRL code, but it is still several times slower than our main training demo. Some basics:
python demo.py
--mode [train, eval, sweep-carbs]
--env [env_name]
--vec [serial, multiprocessing, native, ray]
--track # Track on Weights and Biases. Set your username in demo.py
--help # Display a full list of options
# Get help. One important arg: --train.device cpu if you don't have a GPU
python demo.py --help
# Get help on a specific environment
python demo.py --help --env puffer_snake
# Train breakout with multiprocessing (24 cores):
python demo.py --mode train --env breakout --vec multiprocessing
# Run a hyperparameter sweep on Ocean pong. Requires carbs (github pufferai/carbs):
python demo.py --mode sweep-carbs --env puffer_pong
# Train Ocean snake with wandb logs:
python demo.py --env puffer_snake --mode train --track
# Set train and env params from cli:
python demo.py --env puffer_snake --mode train --train.learning-rate 0.001 --env.vision 3
# Eval a pretrained baseline model:
python demo.py --env puffer_snake --mode eval --baseline
# Eval an uninitialized policy:
python demo.py --env puffer_snake --mode eval
# Eval a local checkpoint:
python demo.py --env puffer_snake --mode eval --eval-model-path your_model.pt
# Useful for finding your latest checkpoint:
ls -lt experiments | head
Compared to the original CleanRL code, our demo file (which loads clean_pufferl.py) supports asynchronous on-policy vectorization, better multi-agent training, a convenient CLI dashboard, better WandB logging and sweeps integration, and more. It's only around 1000 lines of code, most of which is logging.
Vectorization
In RL, vectorization refers to the process of simulating multiple copies of an environment in parallel. Our Multiprocessing backend is fast -- much faster than Gymnasium's in most cases. In our latest benchmark, Atari runs 50-60% faster synchronous and 5x faster async, and some environments like NetHack can be 10x faster even synchronous, with no API changes. Here's how to create a vectorized environment. Note to macOS users: your OS doesn't like to run subprocesses without a __main__ guard (see the sketch at the end of these examples).
from pufferlib.environments import atari
env_creator = atari.env_creator('breakout')
import pufferlib.vector
vecenv = pufferlib.vector.make(
    env_creator,      # A callable (class or function) that returns an env
    env_args=None,    # A list of arguments to pass to each environment
    env_kwargs=None,  # Keyword arguments to pass to each environment
    backend=Serial,   # pufferlib.vector.[Serial|Multiprocessing|Native|Ray]
    num_envs=1,       # The total number of environments to create
    **kwargs,         # Extra backend-specific options
)
# Make 4 copies of Breakout on the current process
vecenv = pufferlib.vector.make(env_creator, num_envs=4,
backend=pufferlib.vector.Serial)
# Make 4 copies of Breakout, each on a separate process
vecenv = pufferlib.vector.make(env_creator, num_envs=4,
backend=pufferlib.vector.Multiprocessing)
# Make 4 copies of Breakout, 2 on each of 2 processes
vecenv = pufferlib.vector.make(env_creator, num_envs=4,
backend=pufferlib.vector.Multiprocessing, num_workers=2)
# Make 4 copies of Breakout, 2 on each of 2 processes,
# but only get two observations per step
vecenv = pufferlib.vector.make(env_creator, num_envs=4,
backend=pufferlib.vector.Multiprocessing, num_workers=2,
batch_size=2)
# Make 1024 instances of Ocean breakout on the current process
from pufferlib.ocean import Breakout
vecenv = pufferlib.vector.make(Breakout,
backend=pufferlib.vector.Native,
env_kwargs={'num_envs': 1024},
)
# Notice that Native envs handle multiple instances internally.
# You can still multiprocess/async, but don't make multiple external
# copies per process.
vecenv = pufferlib.vector.make(Breakout, num_envs=2,
backend=pufferlib.vector.Multiprocessing, batch_size=1)
# Synchronous API - reset/step
import time
vecenv = pufferlib.vector.make(Breakout, num_envs=2,
backend=pufferlib.vector.Multiprocessing)
vecenv.reset()
start, steps, TIMEOUT = time.time(), 0, 3
while time.time() - start < TIMEOUT:
    vecenv.step(vecenv.action_space.sample())
    steps += 1
vecenv.close()
print('Puffer FPS:', steps*vecenv.num_envs/TIMEOUT)
# Async API - async_reset, send/recv
# Call your model between recv() and send()
vecenv = pufferlib.vector.make(Breakout, num_envs=2,
backend=pufferlib.vector.Multiprocessing, batch_size=1)
vecenv.async_reset()
start, steps, TIMEOUT = time.time(), 0, 3
while time.time() - start < TIMEOUT:
    vecenv.recv()
    vecenv.send(vecenv.action_space.sample())
    steps += 1
vecenv.close()
print('Puffer Async FPS:', steps*vecenv.num_envs/TIMEOUT)
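As noted for macOS above: the spawn start method re-imports your script in each worker process, so code that builds a Multiprocessing vecenv needs a __main__ guard. A minimal sketch, reusing the Ocean Breakout example from above:
import pufferlib.vector
from pufferlib.ocean import Breakout

if __name__ == '__main__':
    # Guard everything that spawns workers so re-imported child
    # processes don't try to create their own vecenvs
    vecenv = pufferlib.vector.make(Breakout, num_envs=2,
        backend=pufferlib.vector.Multiprocessing)
    vecenv.reset()
    vecenv.close()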
Our vectorization works on almost any Gymnasium/PettingZoo environment, not just the ones we have bound manually. All you have to do is wrap your environment with our Emulation layer, covered in the next section. PufferLib outperforms other vectorization implementations by implementing the following optimizations:
A Python implementation of EnvPool. Simulates more envs than are needed per batch and returns batches of observations as soon as they are ready. Requires using the async send/recv instead of the sync step API.
Multiple environments per worker. Important for fast environments.
Shared memory. Unlike Gymnasium's implementation, we use a single buffer that is shared across environments.
Shared flags. Workers busy-wait on an unlocked flag instead of signaling via pipes or queues. This virtually eliminates interprocess communication overhead. Pipes are used once per episode to communicate aggregated infos.
Zero-copy batching. Because we use a single buffer for shared memory, we can return observations from contiguous subsets of workers without ever copying observations. The only exception is full-async mode.
Native multiagent support. It's not an extra wrapper or slow bolt-on feature. PufferLib treats single-agent and multi-agent environments the same. API differences are handled at the emulation level.
Emulation
Complex environments may have hierarchical observations and actions, variable numbers of agents, and other quirks that make them difficult to work with and incompatible with standard reinforcement learning libraries. PufferLib's emulation layer makes every environment look like it has flat observations/actions and a constant number of agents. Here's how it works with NetHack and Neural MMO, two notoriously complex environments.
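A minimal sketch of the wrappers (this assumes the nle and nmmo packages are installed; the class names pufferlib.emulation.GymnasiumPufferEnv and PettingZooPufferEnv match recent PufferLib versions, so check yours if the imports differ):
import pufferlib.emulation

import nle, nmmo

def nmmo_creator():
    # Multi-agent PettingZoo env: flat obs/actions, padded to a constant agent count
    return pufferlib.emulation.PettingZooPufferEnv(env_creator=nmmo.Env)

def nethack_creator():
    # Single-agent Gymnasium env with structured (Dict) observations
    return pufferlib.emulation.GymnasiumPufferEnv(env_creator=nle.env.NLE)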
The wrappers give you back a Gymnasium/PettingZoo compliant environment. There is no loss of generality and no change to the underlying environment. You can wrap environments by class, creator function, or object, with or without additional arguments. These wrappers enable us to make some optimizations to vectorization code that would be difficult to implement otherwise.
Policies
You don't want another policy API, so we don't have one. Just write normal PyTorch code. We do provide:
Default policies: A small collection of broadly useful networks. These include MLPs and CNNs.
LSTM Integration: Break your forward() function into encode_observations() and decode_actions() and our LSTM wrapper will handle recurrence for you.
CleanRL API compatibility: Wrappers that format your policy for usage with CleanRL. We use these in our demos. This is mostly fluff -- we're working on cutting down boilerplate.
import torch
from torch import nn
import numpy as np

import pufferlib.pytorch
import pufferlib.spaces
class Default(nn.Module):
    '''Default PyTorch policy. Flattens obs and applies a linear layer.

    PufferLib is not a framework. It does not enforce a base class.
    You can use any PyTorch policy that returns actions and values.
    We structure our forward methods as encode_observations and decode_actions
    to make it easier to wrap policies with LSTMs. You can do that and use
    our LSTM wrapper or implement your own. To port an existing policy
    for use with our LSTM wrapper, simply put everything from forward() before
    the recurrent cell into encode_observations and put everything after
    into decode_actions.
    '''
    def __init__(self, env, hidden_size=128):
        super().__init__()
        self.hidden_size = hidden_size
        self.is_multidiscrete = isinstance(env.single_action_space,
            pufferlib.spaces.MultiDiscrete)
        self.is_continuous = isinstance(env.single_action_space,
            pufferlib.spaces.Box)
        try:
            self.is_dict_obs = isinstance(env.env.observation_space, pufferlib.spaces.Dict)
        except:
            self.is_dict_obs = isinstance(env.observation_space, pufferlib.spaces.Dict)

        if self.is_dict_obs:
            self.dtype = pufferlib.pytorch.nativize_dtype(env.emulated)
            input_size = sum(np.prod(v.shape) for v in env.env.observation_space.values())
            self.encoder = nn.Linear(input_size, self.hidden_size)
        else:
            self.encoder = nn.Linear(np.prod(env.single_observation_space.shape), hidden_size)

        if self.is_multidiscrete:
            action_nvec = env.single_action_space.nvec
            self.decoder = nn.ModuleList([pufferlib.pytorch.layer_init(
                nn.Linear(hidden_size, n), std=0.01) for n in action_nvec])
        elif not self.is_continuous:
            self.decoder = pufferlib.pytorch.layer_init(
                nn.Linear(hidden_size, env.single_action_space.n), std=0.01)
        else:
            self.decoder_mean = pufferlib.pytorch.layer_init(
                nn.Linear(hidden_size, env.single_action_space.shape[0]), std=0.01)
            self.decoder_logstd = nn.Parameter(torch.zeros(
                1, env.single_action_space.shape[0]))

        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, observations):
        hidden, lookup = self.encode_observations(observations)
        actions, value = self.decode_actions(hidden, lookup)
        return actions, value

    def encode_observations(self, observations):
        '''Encodes a batch of observations into hidden states. Assumes
        no time dimension (handled by LSTM wrappers).'''
        batch_size = observations.shape[0]
        if self.is_dict_obs:
            observations = pufferlib.pytorch.nativize_tensor(observations, self.dtype)
            observations = torch.cat([v.view(batch_size, -1) for v in observations.values()], dim=1)
        else:
            observations = observations.view(batch_size, -1)
        return torch.relu(self.encoder(observations.float())), None

    def decode_actions(self, hidden, lookup, concat=True):
        '''Decodes a batch of hidden states into (multi)discrete actions.
        Assumes no time dimension (handled by LSTM wrappers).'''
        value = self.value_head(hidden)
        if self.is_multidiscrete:
            actions = [dec(hidden) for dec in self.decoder]
            return actions, value
        elif self.is_continuous:
            mean = self.decoder_mean(hidden)
            logstd = self.decoder_logstd.expand_as(mean)
            std = torch.exp(logstd)
            probs = torch.distributions.Normal(mean, std)
            return probs, value

        actions = self.decoder(hidden)
        return actions, value
import pufferlib.vector
from pufferlib.ocean import Breakout
vecenv = pufferlib.vector.make(Breakout, backend=pufferlib.vector.Native)
policy = Default(vecenv.driver_env)
obs, _ = vecenv.reset()
obs = torch.as_tensor(obs)
# Use the PyTorch policy raw
actions, value = policy(obs)
# Use our LSTM compatibility layer
from pufferlib.models import LSTMWrapper
lstm_policy = LSTMWrapper(vecenv.driver_env, policy)
state = (
torch.zeros(1, 1, lstm_policy.hidden_size),
torch.zeros(1, 1, lstm_policy.hidden_size),
)
actions, value, state = lstm_policy(obs, state)
# Use our CleanRL API compatibility layer
import pufferlib.cleanrl
cleanrl_policy = pufferlib.cleanrl.Policy(policy)
actions = cleanrl_policy.get_action_and_value(obs)[0].numpy()
# Use our CleanRL LSTM API compatibility layer
cleanrl_lstm_policy = pufferlib.cleanrl.RecurrentPolicy(lstm_policy)
actions = cleanrl_lstm_policy.get_action_and_value(obs)[0].numpy()
obs, rewards, terminals, truncateds, infos = vecenv.step(actions)
vecenv.close()
Remember the unflatten operation in Emulation? Notice our usage of pufferlib.pytorch.nativize_dtype and pufferlib.pytorch.nativize_tensor to unpack structured data in the forward pass. You only need to worry about this if your environment has structured observation data.
Puffer Ocean: First Party Envs & API
Gymnasium and PettingZoo are great, but they impose fundamental performance limitations that cap environments far below 1M steps/second:
They are single-environment formats, not vector formats. This means that the loop over environments is done in Python.
They require you to directly return observations, rather than writing into a shared buffer. This incurs redundant copy operations.
There is a different API for single-agent and multi-agent environments. This causes compatibility issues, and the multi-agent API has a lot of Python overhead.
PufferLib ships with Ocean, our first-party suite of high-performance environments. They are written with our native PufferEnv API, which eliminates traditional bottlenecks. For debugging new code, Ocean also includes specially designed sanity environments that train in seconds and will let you catch 90% of implementation bugs.
PufferEnv is a vector format that easily scales to several million steps per second. It is very similar to Gymnasium's VectorEnv, but with some elements preserved from PettingZoo to better support multi-agent envs without inconveniencing single-agent use. The Python binding is included below. The full source is available in ocean/squared.
'''A simple sample environment. Use this as a template for your own envs.'''
import gymnasium
import numpy as np

import pufferlib
from pufferlib.ocean.squared.cy_squared import CySquared

class Squared(pufferlib.PufferEnv):
    def __init__(self, num_envs=1, render_mode=None, size=11, buf=None):
        self.single_observation_space = gymnasium.spaces.Box(low=0, high=1,
            shape=(size*size,), dtype=np.uint8)
        self.single_action_space = gymnasium.spaces.Discrete(5)
        self.render_mode = render_mode
        self.num_agents = num_envs

        super().__init__(buf)
        self.c_envs = CySquared(self.observations, self.actions,
            self.rewards, self.terminals, num_envs, size)

    def reset(self, seed=None):
        self.c_envs.reset()
        return self.observations, []

    def step(self, actions):
        self.actions[:] = actions
        self.c_envs.step()

        episode_returns = self.rewards[self.terminals]

        info = []
        if len(episode_returns) > 0:
            info = [{
                'reward': np.mean(episode_returns),
            }]

        return (self.observations, self.rewards,
            self.terminals, self.truncations, info)

    def render(self):
        self.c_envs.render()

    def close(self):
        self.c_envs.close()
The key feature here is the allocation of data buffers for observations, rewards, etc. When multiprocessing, PufferLib will place these into a single shared memory tensor, so your data will be available on the main process immediately. In other words, your environment computes observations directly into shared memory with no redundant copies or slow Python glue. The base class is less than 100 lines. We suggest reading it just to see that there is no magic.
The logic for this environment is imported. It is written in C with an intermediate binding in Cython. This is not part of the API. You can implement environment logic in Python or any other language you like. The only hard requirement is that you write data into the provided self.observations, self.actions, self.rewards, self.terminals, and self.truncations buffers. Additionally, there are two soft requirements that we plan to revisit. First, PufferEnvs must handle multiple environment copies internally. This one is easy to eliminate, but you probably don't want Python running this loop anyways. Second, your environment must handle resets internally. This one is just easier for now, and we will revisit if it becomes an issue.
A key goal of PufferLib is to give you the maximum performance based on the amount of engineering effort you are willing to put into your environment. Here is a full list of options, from fastest to slowest:
Native PufferEnv in C: All env logic in C. Carefully written C++ or any other language with a good binding to Python is fine too, but we find C the easiest. All of our new Ocean envs are written this way. One nice benefit is that C compiles to WASM, so it is easy to run web demos.
Native PufferEnv in Cython: All env logic in Cython. This can be as fast as C but requires careful optimization. We provide this option because it is easy for researchers familiar with Python, and we use Cython as a binding layer for C envs anyways, so it is easy to port later.
Native PufferEnv in Python: Uses our native API but keeps all logic in Python. Because of the way we manage observation memory, this can still be much faster than Gymnasium/PettingZoo (see the sketch after this list).
Gymnasium/PettingZoo: Write a standard pre-PufferLib environment and use our Gymnasium/PettingZoo emulation layer for fast parallel simulation.
Not using PufferLib: A return to the dark ages.
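To illustrate the pure-Python option above, here is a hypothetical toy environment. The class and its logic are invented for illustration; the buffer contract it follows (write into the provided buffers, handle resets internally) is the same one the Squared binding uses:
import gymnasium
import numpy as np
import pufferlib

class GuessOne(pufferlib.PufferEnv):
    '''Hypothetical pure-Python PufferEnv: reward 1 for picking action 1.'''
    def __init__(self, num_envs=1, render_mode=None, buf=None):
        self.single_observation_space = gymnasium.spaces.Box(
            low=0, high=1, shape=(1,), dtype=np.uint8)
        self.single_action_space = gymnasium.spaces.Discrete(2)
        self.render_mode = render_mode
        self.num_agents = num_envs
        super().__init__(buf)

    def reset(self, seed=None):
        self.observations[:] = 0
        return self.observations, []

    def step(self, actions):
        # Write results directly into the shared buffers
        self.actions[:] = actions
        self.rewards[:] = self.actions
        self.terminals[:] = True  # One-step episodes, reset internally
        self.observations[:] = 0
        return (self.observations, self.rewards,
            self.terminals, self.truncations, [])
You can pass a class like this to pufferlib.vector.make with the Native backend, exactly as with Breakout and Squared above.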
Third Party Environments
You can use any well-behaved environment with PufferLib via a 1-line wrapper. These are just the environments we have gotten around to binding manually. A lot of environments are not well behaved and subtly deviate from the Gymnasium/PettingZoo API, or they use obscure features that most libraries are not designed to handle. Our bindings just help clean this up a bit. Plus, we add any tricky system package dependencies to PufferTank. Feel free to PR new bindings or fixes for existing ones!
OpenAI Gym is the standard API for single-agent reinforcement learning environments. It also contains some built-in environments. We include Box2D in our registry.
Pokemon Red is one of the original Pokemon games for the Game Boy. This project uses the game as an environment for reinforcement learning. We are actively supporting development on this one!
PettingZoo is the standard API for multi-agent reinforcement learning environments. It also contains some built-in environments. We include Butterfly in our registry.
Arcade Learning Environment provides a Gym interface for classic Atari games. This is the most popular benchmark for reinforcement learning algorithms.
Minigrid is a 2D grid-world environment engine and a collection of built-in environments. It targets flexible and computationally efficient RL research.
MAgent is a platform for large-scale agent simulation.
Neural MMO is a massively multiagent environment for reinforcement learning. It combines large agent populations with high per-agent complexity and is the most actively maintained (by me) project on this list.
Procgen is a suite of arcade games for reinforcement learning with procedurally generated levels. It is one of the most computationally efficient environments on this list.
NetHack Learning Environment is a port of the classic game NetHack to the Gym API. It combines extreme complexity with high simulation efficiency.
MiniHack Learning Environment is a stripped down version of NetHack with support for level editing and custom procedural generation.
Crafter is a top-down 2D Minecraft clone for RL research. It provides pixel observations and relatively long time horizons.
GPUDrive is a GPU-accelerated, multi-agent driving simulator that runs at 1 million FPS.
Griddly is an extremely optimized platform for building reinforcement learning environments. It also includes a large suite of built-in environments.
Gym MicroRTS is a real time strategy engine for reinforcement learning research. The Java configuration is a bit finicky -- we're still debugging this.