from typing import Sequence
from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn


class DQNNet(nn.Module):
    """Convolutional Q-value network: three conv layers, then dense layers.

    The first three entries of ``features`` configure the convolutional
    layers; any remaining entries configure hidden dense layers. A final
    dense layer of width ``n_actions`` produces the output.
    """

    # features[0], features[1], features[2]: `features` argument of the three
    # nn.Conv layers. features[3:]: widths of the hidden nn.Dense layers.
    # Must contain at least three entries (indexing raises otherwise).
    features: Sequence[int]
    # Width of the output nn.Dense layer.
    n_actions: int

    @nn.compact
    def __call__(self, x):
        """Run the network on ``x`` and return the output of the last layer.

        Args:
            x: Input array. Inputs with fewer than four dimensions get
                leading singleton axes prepended (``ndmin=4``), so a single
                unbatched input is processed as a batch of one.
                NOTE(review): presumably image observations with values in
                [0, 255], given the division by 255.0 -- confirm with caller.

        Returns:
            Array of shape ``(batch, n_actions)``, or ``(n_actions,)`` when
            the leading (batch) axis has size one.
        """
        initializer = nn.initializers.xavier_uniform()

        # Promote to rank >= 4 and rescale before the first convolution.
        x = jnp.array(x, ndmin=4) / 255.0

        x = nn.relu(
            nn.Conv(features=self.features[0], kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(x)
        )
        x = nn.relu(
            nn.Conv(features=self.features[1], kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x)
        )
        x = nn.relu(
            nn.Conv(features=self.features[2], kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x)
        )

        # Flatten everything except the leading axis.
        x = x.reshape((x.shape[0], -1))

        # Drop the leading axis only when it has size one. An axis-less
        # squeeze would also remove a size-one feature axis, which makes the
        # dense layers below operate on the wrong axis (or on a scalar).
        # Shapes are static under jit, so a Python-level check is fine here.
        if x.shape[0] == 1:
            x = jnp.squeeze(x, axis=0)

        for width in self.features[3:]:
            x = nn.relu(nn.Dense(width, kernel_init=initializer)(x))

        return nn.Dense(self.n_actions, kernel_init=initializer)(x)


class QNetwork:
    """Thin wrapper that owns a ``DQNNet`` and exposes greedy action selection."""

    def __init__(self, features: Sequence[int], n_actions: int) -> None:
        """Build the underlying ``DQNNet``.

        Args:
            features: Forwarded to ``DQNNet.features``.
            n_actions: Forwarded to ``DQNNet.n_actions``.
        """
        self.network = DQNNet(features, n_actions)

    # `self` is marked static so jit does not try to trace it; it is therefore
    # hashed by jit and must stay hashable.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params, state: jnp.ndarray) -> jnp.int8:
        """Return the index of the largest network output for ``state``.

        Args:
            params: Variables passed to ``self.network.apply``.
            state: Input forwarded to the network.

        Returns:
            The argmax index cast to ``int8``.

        NOTE: ``jnp.argmax`` is called without an axis, so the index refers to
        the flattened network output; this is only an action index when a
        single state is passed. The ``int8`` cast cannot represent indices
        above 127.
        """
        q_values = self.network.apply(params, state)
        return jnp.argmax(q_values).astype(jnp.int8)