What is Stable Baselines3 (SB3)

I have just read about this new release. It is a complete rewrite of Stable Baselines (version 2) that drops TensorFlow entirely and is based on PyTorch (1.4+).

There are many ready-to-run implementations of RL algorithms based on Gym. A very good introduction is given in this blog entry.

My installation

Standard installation

conda create --name stablebaselines3 python=3.7
conda activate stablebaselines3
pip install stable-baselines3[extra]
conda install -c conda-forge jupyter_contrib_nbextensions
conda install nb_conda
!conda list
# packages in environment at /home/explore/miniconda3/envs/stablebaselines3:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_pytorch_select           0.1                       cpu_0  
absl-py                   0.12.0                   pypi_0    pypi
atari-py                  0.2.6                    pypi_0    pypi
attrs                     20.3.0             pyhd3deb0d_0    conda-forge
backcall                  0.2.0              pyh9f0ad1d_0    conda-forge
backports                 1.0                        py_2    conda-forge
backports.functools_lru_cache 1.6.1                      py_0    conda-forge
blas                      1.0                         mkl  
bleach                    3.3.0              pyh44b312d_0    conda-forge
box2d                     2.3.10                   pypi_0    pypi
box2d-py                  2.3.8                    pypi_0    pypi
ca-certificates           2021.1.19            h06a4308_1  
cachetools                4.2.1                    pypi_0    pypi
certifi                   2020.12.5        py37h06a4308_0  
cffi                      1.14.5           py37h261ae71_0  
chardet                   4.0.0                    pypi_0    pypi
cloudpickle               1.6.0                    pypi_0    pypi
cudatoolkit               11.0.221             h6bb024c_0  
cycler                    0.10.0                   pypi_0    pypi
decorator                 4.4.2                      py_0    conda-forge
defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
entrypoints               0.3             pyhd8ed1ab_1003    conda-forge
fire                      0.4.0              pyh44b312d_0    conda-forge
freetype                  2.10.4               h5ab3b9f_0  
future                    0.18.2                   pypi_0    pypi
google-auth               1.28.0                   pypi_0    pypi
google-auth-oauthlib      0.4.3                    pypi_0    pypi
grpcio                    1.36.1                   pypi_0    pypi
gym                       0.18.0                   pypi_0    pypi
icu                       58.2              hf484d3e_1000    conda-forge
idna                      2.10                     pypi_0    pypi
importlib-metadata        3.7.3            py37h89c1867_0    conda-forge
intel-openmp              2019.4                      243  
ipykernel                 5.5.0            py37h888b3d9_1    conda-forge
ipython                   7.21.0           py37h888b3d9_0    conda-forge
ipython_genutils          0.2.0                      py_1    conda-forge
jedi                      0.18.0           py37h89c1867_2    conda-forge
jinja2                    2.11.3             pyh44b312d_0    conda-forge
jpeg                      9b                   h024ee3a_2  
jsonschema                3.2.0              pyhd8ed1ab_3    conda-forge
jupyter_client            6.1.12             pyhd8ed1ab_0    conda-forge
jupyter_contrib_core      0.3.3                      py_2    conda-forge
jupyter_contrib_nbextensions 0.5.1              pyhd8ed1ab_2    conda-forge
jupyter_core              4.7.1            py37h89c1867_0    conda-forge
jupyter_highlight_selected_word 0.2.0           py37h89c1867_1002    conda-forge
jupyter_latex_envs        1.4.6           pyhd8ed1ab_1002    conda-forge
jupyter_nbextensions_configurator 0.4.1            py37h89c1867_2    conda-forge
kiwisolver                1.3.1                    pypi_0    pypi
lcms2                     2.11                 h396b838_0  
ld_impl_linux-64          2.33.1               h53a641e_7  
libffi                    3.3                  he6710b0_2  
libgcc-ng                 9.1.0                hdf63c60_0  
libmklml                  2019.0.5                      0  
libpng                    1.6.37               hbc83047_0  
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libstdcxx-ng              9.1.0                hdf63c60_0  
libtiff                   4.2.0                h85742a9_0  
libuv                     1.40.0               h7b6447c_0  
libwebp-base              1.2.0                h27cfd23_0  
libxml2                   2.9.10               hb55368b_3  
libxslt                   1.1.34               hc22bd24_0  
lxml                      4.6.3            py37h9120a33_0  
lz4-c                     1.9.3                h2531618_0  
markdown                  3.3.4                    pypi_0    pypi
markupsafe                1.1.1            py37hb5d75c8_2    conda-forge
matplotlib                3.3.4                    pypi_0    pypi
mistune                   0.8.4           py37h4abf009_1002    conda-forge
mkl                       2020.2                      256  
mkl-service               2.3.0            py37he8ac12f_0  
mkl_fft                   1.3.0            py37h54f3939_0  
mkl_random                1.1.1            py37h0573a6f_0  
nb_conda                  2.2.1                    py37_0  
nb_conda_kernels          2.3.1            py37h06a4308_0  
nbconvert                 5.6.1            py37hc8dfbb8_1    conda-forge
nbformat                  5.1.2              pyhd8ed1ab_1    conda-forge
ncurses                   6.2                  he6710b0_1  
ninja                     1.10.2           py37hff7bd54_0  
notebook                  5.7.10           py37hc8dfbb8_0    conda-forge
numpy                     1.20.1                   pypi_0    pypi
numpy-base                1.19.2           py37hfa32c7d_0  
oauthlib                  3.1.0                    pypi_0    pypi
olefile                   0.46                     py37_0  
opencv-python             4.5.1.48                 pypi_0    pypi
openssl                   1.1.1j               h27cfd23_0  
packaging                 20.9               pyh44b312d_0    conda-forge
pandas                    1.2.3                    pypi_0    pypi
pandoc                    2.12                 h7f98852_0    conda-forge
pandocfilters             1.4.2                      py_1    conda-forge
parso                     0.8.1              pyhd8ed1ab_0    conda-forge
pexpect                   4.8.0              pyh9f0ad1d_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    7.2.0                    pypi_0    pypi
pip                       21.0.1           py37h06a4308_0  
prometheus_client         0.9.0              pyhd3deb0d_0    conda-forge
prompt-toolkit            3.0.18             pyha770c72_0    conda-forge
protobuf                  3.15.6                   pypi_0    pypi
psutil                    5.8.0                    pypi_0    pypi
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pyasn1                    0.4.8                    pypi_0    pypi
pyasn1-modules            0.2.8                    pypi_0    pypi
pycparser                 2.20                       py_2  
pyglet                    1.5.0                    pypi_0    pypi
pygments                  2.8.1              pyhd8ed1ab_0    conda-forge
pyparsing                 2.4.7              pyh9f0ad1d_0    conda-forge
pyrsistent                0.17.3           py37h4abf009_1    conda-forge
python                    3.7.10               hdb3f193_0  
python-dateutil           2.8.1                      py_0    conda-forge
python_abi                3.7                     1_cp37m    conda-forge
pytorch                   1.7.1           py3.7_cuda11.0.221_cudnn8.0.5_0    pytorch
pytz                      2021.1                   pypi_0    pypi
pyyaml                    5.3.1            py37hb5d75c8_1    conda-forge
pyzmq                     19.0.2           py37hac76be4_2    conda-forge
readline                  8.1                  h27cfd23_0  
requests                  2.25.1                   pypi_0    pypi
requests-oauthlib         1.3.0                    pypi_0    pypi
rsa                       4.7.2                    pypi_0    pypi
scipy                     1.6.1                    pypi_0    pypi
send2trash                1.5.0                      py_0    conda-forge
setuptools                52.0.0           py37h06a4308_0  
six                       1.15.0             pyh9f0ad1d_0    conda-forge
sqlite                    3.35.2               hdfb4753_0  
stable-baselines3         1.0                      pypi_0    pypi
tensorboard               2.4.1                    pypi_0    pypi
tensorboard-plugin-wit    1.8.0                    pypi_0    pypi
termcolor                 1.1.0            py37h06a4308_1  
terminado                 0.9.3            py37h89c1867_0    conda-forge
testpath                  0.4.4                      py_0    conda-forge
tk                        8.6.10               hbc83047_0  
torchaudio                0.7.2                      py37    pytorch
torchvision               0.8.2                py37_cu110    pytorch
tornado                   6.1              py37h4abf009_0    conda-forge
traitlets                 5.0.5                      py_0    conda-forge
typing-extensions         3.7.4.3                       0  
typing_extensions         3.7.4.3                    py_0    conda-forge
urllib3                   1.26.4                   pypi_0    pypi
wcwidth                   0.2.5              pyh9f0ad1d_2    conda-forge
webencodings              0.5.1                      py_1    conda-forge
werkzeug                  1.0.1                    pypi_0    pypi
wheel                     0.36.2             pyhd3eb1b0_0  
xz                        5.2.5                h7b6447c_0  
yaml                      0.2.5                h516909a_0    conda-forge
zeromq                    4.3.4                h2531618_0  
zipp                      3.4.1              pyhd8ed1ab_0    conda-forge
zlib                      1.2.11               h7b6447c_3  
zstd                      1.4.5                h9ceee32_0  

SB3 tutorials

import gym

from stable_baselines3 import A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback, EvalCallback

# Save a checkpoint every 5000 steps
checkpoint_callback = CheckpointCallback(save_freq=5000, save_path="/home/explore/git/guillaume/stable_baselines_3/logs/",
                                         name_prefix="rl_model")

# Evaluate the model periodically
# and auto-save the best model and evaluations
# Use a monitor wrapper to properly report episode stats
eval_env = Monitor(gym.make("LunarLander-v2"))
# Use deterministic actions for evaluation
eval_callback = EvalCallback(eval_env, best_model_save_path="/home/explore/git/guillaume/stable_baselines_3/logs/",
                             log_path="/home/explore/git/guillaume/stable_baselines_3/logs/", eval_freq=2000,
                             deterministic=True, render=False)

# Train an agent using A2C on LunarLander-v2
model = A2C("MlpPolicy", "LunarLander-v2", verbose=1)
model.learn(total_timesteps=20000, callback=[checkpoint_callback, eval_callback])

# Retrieve and reset the environment
env = model.get_env()
obs = env.reset()

# Query the agent (stochastic action here)
action, _ = model.predict(obs, deterministic=False)
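After training, the evaluations recorded by EvalCallback can be inspected offline. To my understanding, SB3 1.0 saves them as evaluations.npz inside log_path, with a timesteps array and a results array of shape (number of evaluations, episodes per evaluation). Below is a minimal sketch that assumes that layout; summarize_evaluations is a hypothetical helper name, not an SB3 function.

```python
import os

import numpy as np


def summarize_evaluations(log_path):
    """Return (timestep, mean_reward, std_reward) tuples from EvalCallback's log.

    Assumes the layout written by SB3's EvalCallback: an ``evaluations.npz``
    archive with ``timesteps`` of shape (n_evals,) and ``results`` of shape
    (n_evals, n_eval_episodes).
    """
    # Use a context manager so the .npz file handle is closed afterwards
    with np.load(os.path.join(log_path, "evaluations.npz")) as data:
        timesteps = data["timesteps"]
        results = data["results"]
    # One entry per evaluation: average and spread over the evaluation episodes
    means = results.mean(axis=1)
    stds = results.std(axis=1)
    return [(int(t), float(m), float(s)) for t, m, s in zip(timesteps, means, stds)]
```

For example, calling summarize_evaluations("/home/explore/git/guillaume/stable_baselines_3/logs/") after the run above should give one tuple per evaluation, i.e. every 2000 steps with a single training environment.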

Issues and fixes

CUDA error: CUBLAS_STATUS_INTERNAL_ERROR

Downgrade PyTorch to 1.7.1 to avoid RuntimeError: CUDA error: CUBLAS_STATUS_INTERNAL_ERROR when calling cublasCreate(handle):

pip install torch==1.7.1

RuntimeError: CUDA error: invalid device function

!nvidia-smi
Thu Mar 25 09:13:49 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro RTX 4000     Off  | 00000000:01:00.0  On |                  N/A |
| N/A   41C    P5    18W /  N/A |   2104MiB /  7982MiB |     32%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1153      G   /usr/lib/xorg/Xorg                162MiB |
|    0   N/A  N/A      1904      G   /usr/lib/xorg/Xorg                268MiB |
|    0   N/A  N/A      2076      G   /usr/bin/gnome-shell              403MiB |
|    0   N/A  N/A      2697      G   ...gAAAAAAAAA --shared-files       54MiB |
|    0   N/A  N/A      7220      G   ...AAAAAAAAA= --shared-files       84MiB |
|    0   N/A  N/A     57454      G   /usr/lib/firefox/firefox            2MiB |
|    0   N/A  N/A     59274      C   ...ablebaselines3/bin/python     1051MiB |
+-----------------------------------------------------------------------------+

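On my workstation, nvidia-smi reports CUDA Version 11.0. This is the highest CUDA version supported by the driver (450.102.04), not necessarily the installed toolkit: nvcc reports release 10.1. The conda build of PyTorch ships its own cudatoolkit, so I requested the build matching CUDA 11.0: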

!nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243
!conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=11.0 -c pytorch
Collecting package metadata (current_repodata.json): done
Solving environment: done

# All requested packages already installed.

Everything seems fine after these updates.

Stable baselines 3 user guide

Stable Baselines3 comes with impressive documentation; a good entry point is the Quickstart page.

Tips and tricks

This page covers general advice about RL (where to start, which algorithm to choose, how to evaluate an algorithm, …), as well as tips and tricks when using a custom environment or implementing an RL algorithm.

Tune hyperparameters: the RL Zoo is introduced. It provides tuned hyperparameters and scripts for hyperparameter optimization.

RL evaluation: we suggest reading Deep Reinforcement Learning that Matters for a good discussion of RL evaluation.

Which algorithm to choose: the first criterion is whether actions are discrete or continuous; the second is whether training can be parallelized across multiple processes.

Discrete Actions

  • Discrete Actions - Single Process

DQN with extensions (double DQN, prioritized replay, …) are the recommended algorithms. We notably provide QR-DQN in our contrib repo. DQN is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer).

  • Discrete Actions - Multiprocessed

You should try PPO or A2C.

Continuous Actions

  • Continuous Actions - Single Process

Current State Of The Art (SOTA) algorithms are SAC, TD3 and TQC (available in our contrib repo). Please use the hyperparameters in the RL zoo for best results.

  • Continuous Actions - Multiprocessed

Take a look at PPO or A2C. Again, don’t forget to take the hyperparameters from the RL zoo for continuous actions problems (cf Bullet envs).
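The two criteria above boil down to a small lookup table. The sketch below is a hypothetical helper (not part of SB3) that returns the algorithms recommended on the Tips and Tricks page; QR-DQN and TQC are the ones living in the contrib repo (sb3-contrib).

```python
# Recommendations from the SB3 "Tips and Tricks" page, keyed by
# (action space type, whether training runs in multiple processes).
RECOMMENDED_ALGOS = {
    ("discrete", False): ["DQN", "QR-DQN"],  # QR-DQN is in sb3-contrib
    ("discrete", True): ["PPO", "A2C"],
    ("continuous", False): ["SAC", "TD3", "TQC"],  # TQC is in sb3-contrib
    ("continuous", True): ["PPO", "A2C"],
}


def choose_algorithms(action_space: str, multiprocessed: bool) -> list:
    """Return candidate algorithm names for the given setting (hypothetical helper)."""
    key = (action_space.lower(), bool(multiprocessed))
    if key not in RECOMMENDED_ALGOS:
        raise ValueError(
            f"action_space must be 'discrete' or 'continuous', got {action_space!r}"
        )
    # Return a copy so callers cannot mutate the shared table
    return list(RECOMMENDED_ALGOS[key])
```

Whatever the helper suggests, the docs still insist on taking the hyperparameters from the RL zoo.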

Creating a custom env

The docs repeatedly advise normalizing both the observation space and the action space. A good practice is to rescale your actions to lie in [-1, 1]. This does not limit you, as you can easily rescale the actions inside the environment.
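A minimal NumPy sketch of that rescaling, assuming the environment's native action bounds are low and high: the agent acts in [-1, 1], and the environment maps each action back to its native range before applying it. The function names are illustrative, not taken from the docs.

```python
import numpy as np


def scale_action(action, low, high):
    """Map an action from the native range [low, high] to [-1, 1]."""
    low = np.asarray(low, dtype=np.float64)
    high = np.asarray(high, dtype=np.float64)
    return 2.0 * (np.asarray(action, dtype=np.float64) - low) / (high - low) - 1.0


def unscale_action(scaled_action, low, high):
    """Map an action from [-1, 1] back to [low, high], e.g. inside env.step()."""
    low = np.asarray(low, dtype=np.float64)
    high = np.asarray(high, dtype=np.float64)
    # Clip first so that out-of-range policy outputs stay within the bounds
    scaled = np.clip(np.asarray(scaled_action, dtype=np.float64), -1.0, 1.0)
    return low + 0.5 * (scaled + 1.0) * (high - low)
```

In a custom environment, the action space would then be declared as Box(-1, 1) and unscale_action applied at the top of step().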

Tips and tricks to reproduce an RL paper

Reinforcement Learning Tips and Tricks — Stable Baselines3 1.1.0a1 documentation

A personal pick (by @araffin) for environments with gradually increasing difficulty in RL with continuous actions:

  1. Pendulum (easy to solve)

  2. HalfCheetahBullet (medium difficulty with local minima and shaped reward)

  3. BipedalWalkerHardcore (if it works on that one, then you can have a cookie)

In RL with discrete actions:

  1. CartPole-v1 (easy to be better than random agent, harder to achieve maximal performance)

  2. LunarLander

  3. Pong (one of the easiest Atari games)

  4. other Atari games (e.g. Breakout)

Resource page

Reinforcement Learning Resources — Stable Baselines3 1.1.0a1 documentation

Stable-Baselines3 assumes that you already understand the basic concepts of Reinforcement Learning (RL).

However, if you want to learn about RL, there are several good resources to get started:

Examples

I will run these examples in 01 -hands-on.ipynb from handson_stablebaselines3

DQN LunarLander

My lunar module never lands :(

Note: animated GIF created with Peek.

PPO with multiprocessing: CartPole

Monitor training using callback

This could be useful when you want to monitor training, for instance to display live learning curves in TensorBoard (or in Visdom) or to save the best agent.
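When the Monitor wrapper is given a filename (Monitor(env, filename=log_dir)), it logs one row per finished episode to a monitor.csv file: a first line starting with # that holds JSON metadata, then the columns r, l and t (reward, length, elapsed time). A "save the best agent" callback typically compares the mean reward of the last episodes with the best value seen so far. Below is a stdlib-only sketch of that decision, assuming this file layout; all names are hypothetical, and SB3 itself provides load_results and ts2xy in stable_baselines3.common.results_plotter for the same purpose.

```python
import csv
import math


def read_episode_rewards(monitor_csv_path):
    """Return the episode rewards stored in a Monitor CSV file.

    Assumed layout: a '#'-prefixed metadata line, then a CSV header
    containing at least the columns r (reward), l (length), t (time).
    """
    with open(monitor_csv_path, newline="") as f:
        first_line = f.readline()
        if not first_line.startswith("#"):
            raise ValueError("expected a '#'-prefixed metadata line")
        reader = csv.DictReader(f)
        return [float(row["r"]) for row in reader]


def mean_last_rewards(rewards, window=100):
    """Mean reward over the last `window` episodes (NaN if there is none yet)."""
    if not rewards:
        return math.nan
    recent = rewards[-window:]
    return sum(recent) / len(recent)


class BestRewardTracker:
    """Hypothetical helper mimicking the decision of a 'save best agent' callback."""

    def __init__(self, window=100):
        self.window = window
        self.best_mean_reward = -math.inf

    def update(self, rewards):
        """Return True when the moving-average reward beats the best so far."""
        mean_reward = mean_last_rewards(rewards, self.window)
        if not math.isnan(mean_reward) and mean_reward > self.best_mean_reward:
            self.best_mean_reward = mean_reward
            return True  # a real SB3 callback would call self.model.save(...) here
        return False
```

Inside an SB3 BaseCallback, update() would be called every few steps from _on_step(), with the rewards re-read from the monitor file.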

Atari games such as Pong (A2C with 6 environments) or Breakout

Here is the list of valid Gym Atari environments: https://gym.openai.com/envs/#atari

sb3_breakout.gif

pybullet

This is an SDK for real-time collision detection and multi-physics simulation for VR, games, visual effects, robotics, machine learning, etc.

https://github.com/bulletphysics/bullet3/

We need to install it: pip install pybullet

I don't get any rendering when playing with it. Since robotics is far from my needs, I will skip this one.

Hindsight Experience Replay (HER)

using Highway-Env

Install it with pip install highway-env

After 1h15m of training, here are some first results:

The examples then cover some more technical topics, such as:

  • Learning Rate Schedule: start with a high value and reduce it as training progresses
  • Advanced Saving and Loading: how to easily create a test environment to evaluate an agent periodically, use a policy independently from a model (and how to save it, load it) and save/load a replay buffer.
  • Accessing and modifying model parameters: these functions are useful when you need to, e.g., evaluate a large set of models with the same network structure, visualize different layers of the network or modify parameters manually.
  • Record a video or make a gif
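For the learning rate schedule, SB3 accepts a callable in place of a constant learning_rate. It calls it with progress_remaining, which decreases from 1 at the start of training to 0 at the end. The sketch below implements a linear decay in the spirit of the example in the SB3 docs; linear_schedule is just a local helper name.

```python
from typing import Callable


def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Return a schedule that decays linearly from initial_value to 0.

    SB3 calls the schedule with progress_remaining, which decreases
    from 1.0 at the start of training to 0.0 at the end.
    """

    def schedule(progress_remaining: float) -> float:
        return progress_remaining * initial_value

    return schedule
```

It would be passed as, for example, A2C("MlpPolicy", "LunarLander-v2", learning_rate=linear_schedule(7e-4)), where 7e-4 is A2C's default learning rate in SB3.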

Make a GIF of a Trained Agent

pip install imageio
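The sketch below collects frames in the spirit of the SB3 docs example. It assumes the old Gym API used here (gym 0.18: reset() returns only the observation, step() returns a 4-tuple, and render(mode="rgb_array") returns an image), and that env is the VecEnv returned by model.get_env(), which resets itself automatically at the end of an episode. collect_frames is a hypothetical helper name; it only relies on model.predict and those environment methods.

```python
import numpy as np


def collect_frames(model, env, n_steps=350, keep_every=2):
    """Run the agent for n_steps and return every keep_every-th rendered frame.

    Assumes the gym 0.18-style API: env.reset() -> obs,
    env.step(action) -> (obs, reward, done, info) and
    env.render(mode="rgb_array") -> HxWx3 image array.
    A VecEnv auto-resets on done, so no manual reset is needed here.
    """
    frames = []
    obs = env.reset()
    for step in range(n_steps):
        img = env.render(mode="rgb_array")
        # Keeping only every other frame gives a smaller GIF
        if step % keep_every == 0:
            frames.append(np.asarray(img))
        action, _ = model.predict(obs, deterministic=True)
        obs, _, _, _ = env.step(action)
    return frames
```

The frames can then be written with imageio, for example: frames = collect_frames(model, model.get_env()) followed by imageio.mimsave("lander_a2c.gif", frames).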

And this time the lander gets closer to the Moon, but it does not land between the flags at all.