# Atari_EauDeQN / networks.py
# Author: TheoVincent · commit 83e5230 ("main files") · 1.35 kB
from typing import Sequence
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn
class DQNNet(nn.Module):
    """Convolutional Q-network: three conv layers, optional hidden dense layers, linear head.

    The convolution stack (8x8 stride 4, 4x4 stride 2, 3x3 stride 1) is the one used by the
    Nature DQN architecture. All kernels use Xavier-uniform initialisation; biases use the
    flax defaults.

    Attributes:
        features: Layer widths. ``features[0]``, ``features[1]`` and ``features[2]`` are the
            output channels of the three convolutions; every further entry adds a ReLU dense
            layer of that width. Must contain at least three entries.
        n_actions: Width of the final (linear) dense layer, i.e. one output per action.
    """

    features: Sequence[int]  # conv channels (first three entries), then hidden dense widths
    n_actions: int  # number of outputs of the linear head

    @nn.compact
    def __call__(self, x):
        """Compute one output per action for ``x``.

        ``x`` is promoted to at least 4 dimensions by prepending size-1 axes and divided by
        255 before the first convolution. Flax ``nn.Conv`` treats the last axis as channels.

        Args:
            x: A single input with fewer than 4 dims (e.g. ``(H, W, C)``) or a batch of
                inputs with 4 dims (``(B, H, W, C)``).

        Returns:
            An array of shape ``(n_actions,)`` for a single input, or ``(B, n_actions)`` for a
            batch -- including a batch of size 1.
        """
        initializer = nn.initializers.xavier_uniform()

        # Record whether the caller supplied its own batch axis. Squeezing every size-1 axis
        # after flattening would also drop the batch axis of a genuine batch of size 1 and
        # change the output rank, so only the axis added by ``ndmin=4`` is removed later.
        unbatched = jnp.ndim(x) < 4
        # Presumably uint8 pixel values rescaled to [0, 1] -- TODO confirm the input dtype.
        x = jnp.array(x, ndmin=4) / 255.0

        x = nn.relu(nn.Conv(features=self.features[0], kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(x))
        x = nn.relu(nn.Conv(features=self.features[1], kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x))
        x = nn.relu(nn.Conv(features=self.features[2], kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x))

        # Flatten everything except the leading (batch) axis.
        x = x.reshape((x.shape[0], -1))
        if unbatched:
            # Drop only the batch axis that ``ndmin=4`` introduced above.
            x = jnp.squeeze(x, axis=0)

        for width in self.features[3:]:
            x = nn.relu(nn.Dense(width, kernel_init=initializer)(x))
        return nn.Dense(self.n_actions, kernel_init=initializer)(x)
class QNetwork:
    """Wrapper that owns a ``DQNNet`` and exposes greedy (argmax) action selection."""

    def __init__(self, features: Sequence[int], n_actions: int) -> None:
        """Build the underlying ``DQNNet``.

        Args:
            features: Layer widths forwarded unchanged to ``DQNNet``.
            n_actions: Number of network outputs, forwarded unchanged to ``DQNNet``.
        """
        self.network = DQNNet(features=features, n_actions=n_actions)

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params, state: jnp.ndarray) -> jnp.int8:
        """Return the index of the largest network output for ``state``, as ``int8``.

        ``self`` is a static argument of ``jax.jit``; since this class defines no
        ``__eq__``/``__hash__``, compiled code is cached per instance (by identity).

        ``jnp.argmax`` is called without an axis, so for a multi-dimensional output the
        result is an index into the flattened array. Indices above 127 do not fit in
        ``int8``.
        """
        q_values = self.network.apply(params, state)
        greedy_index = jnp.argmax(q_values)
        return greedy_index.astype(jnp.int8)