|
from typing import Sequence |
|
from functools import partial |
|
import jax |
|
import jax.numpy as jnp |
|
import flax.linen as nn |
|
|
|
|
|
class DQNNet(nn.Module):
    """Convolutional Q-value network: three conv layers, then dense layers.

    Calling the module returns the output of a final dense layer that has
    ``n_actions`` units.
    """

    # features[0], features[1], features[2] are the output channel counts of
    # the three convolutions; every entry from index 3 onwards is the width
    # of one hidden dense layer. At least three entries are required
    # (shorter sequences fail with IndexError on the conv lookups).
    features: Sequence[int]
    # Number of units in the output dense layer.
    n_actions: int

    @nn.compact
    def __call__(self, x):
        """Run the network on ``x``.

        Args:
            x: Input array. It is promoted to at least 4 dimensions (unit
                axes are prepended) and divided by 255.0 before the first
                convolution — presumably raw 8-bit image observations,
                TODO confirm against the caller.

        Returns:
            The output of the last dense layer. When the leading (batch)
            axis has size 1 it is dropped, so a single input yields a 1-D
            array of length ``n_actions``; otherwise the leading axis is
            kept.
        """
        initializer = nn.initializers.xavier_uniform()

        # ndmin=4 lets a single un-batched input go through the same path
        # as a batch.
        x = jnp.array(x, ndmin=4) / 255.0

        x = nn.relu(
            nn.Conv(
                features=self.features[0],
                kernel_size=(8, 8),
                strides=(4, 4),
                kernel_init=initializer,
            )(x)
        )
        x = nn.relu(
            nn.Conv(
                features=self.features[1],
                kernel_size=(4, 4),
                strides=(2, 2),
                kernel_init=initializer,
            )(x)
        )
        x = nn.relu(
            nn.Conv(
                features=self.features[2],
                kernel_size=(3, 3),
                strides=(1, 1),
                kernel_init=initializer,
            )(x)
        )

        # Flatten everything except the leading axis.
        x = x.reshape((x.shape[0], -1))

        # Drop ONLY a singleton leading axis. A bare jnp.squeeze(x) would
        # also remove the flattened axis whenever it has length 1, which
        # turns (batch, 1) into (batch,) and makes the dense layers treat
        # the batch as the feature axis. Shapes are static, so this branch
        # is resolved at trace time.
        if x.shape[0] == 1:
            x = jnp.squeeze(x, axis=0)

        # Layers are created in the same order as before, so the
        # auto-generated parameter names are unchanged.
        for width in self.features[3:]:
            x = nn.relu(nn.Dense(width, kernel_init=initializer)(x))

        return nn.Dense(self.n_actions, kernel_init=initializer)(x)
|
|
|
|
|
class QNetwork:
    """Holds a ``DQNNet`` and exposes jit-compiled greedy action selection."""

    def __init__(self, features: Sequence[int], n_actions: int) -> None:
        """Create the wrapped network.

        Args:
            features: Forwarded unchanged to ``DQNNet``.
            n_actions: Forwarded unchanged to ``DQNNet``.
        """
        self.network = DQNNet(features, n_actions)

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params, state: jnp.ndarray) -> jnp.int8:
        """Return the index of the largest network output for ``state``.

        ``self`` is marked static for ``jax.jit``, so the instance must be
        hashable and each distinct instance is traced separately.

        Args:
            params: Parameters passed to ``self.network.apply``.
            state: Input passed to ``self.network.apply``.

        Returns:
            The argmax over the flattened network output, cast to ``int8``.
        """
        q_values = self.network.apply(params, state)
        greedy_index = jnp.argmax(q_values)
        return greedy_index.astype(jnp.int8)
|
|