File size: 1,347 Bytes
83e5230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import Sequence
from functools import partial
import jax
import jax.numpy as jnp
import flax.linen as nn


class DQNNet(nn.Module):
    """Convolutional network producing one Q-value per discrete action.

    The layer layout resembles the Atari DQN network: three convolutions
    (8x8 stride 4, 4x4 stride 2, 3x3 stride 1) whose channel counts are the
    first three entries of ``features``. Any remaining entries of ``features``
    are the widths of ReLU ``Dense`` hidden layers. A final linear ``Dense``
    layer emits ``n_actions`` outputs. All kernels use Xavier-uniform
    initialisation.
    """

    # Conv channel counts (first three entries) followed by hidden Dense widths.
    features: Sequence[int]
    # Number of outputs of the final Dense layer (one Q-value per action).
    n_actions: int

    @nn.compact
    def __call__(self, x):
        """Compute Q-values for ``x``.

        Args:
            x: A single observation or a batch of observations. Leading axes
                are added until it has at least 4 dimensions, and it is divided
                by 255. It presumably holds raw pixel values in ``[0, 255]``
                laid out as ``(batch, height, width, channels)`` — TODO confirm
                with callers.

        Returns:
            Q-values of shape ``(batch, n_actions)``. When the batch size is 1,
            the batch axis is dropped and the shape is ``(n_actions,)``.

        Raises:
            ValueError: If ``features`` has fewer than three entries (one per
                convolutional layer).
        """
        if len(self.features) < 3:
            raise ValueError(
                "features must contain at least 3 entries (conv channel counts), "
                f"got {len(self.features)}"
            )

        initializer = nn.initializers.xavier_uniform()

        # Promote a lone observation to a batch of one and scale pixels to [0, 1].
        x = jnp.array(x, ndmin=4) / 255.0

        # Submodules are created in the same order as before, so Flax's
        # auto-generated names (Conv_0..Conv_2, Dense_0..) and the parameter
        # tree stay unchanged.
        x = nn.relu(nn.Conv(features=self.features[0], kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(x))
        x = nn.relu(nn.Conv(features=self.features[1], kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x))
        x = nn.relu(nn.Conv(features=self.features[2], kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x))

        # Flatten everything except the batch axis.
        x = x.reshape((x.shape[0], -1))

        # For a single observation, drop only the batch axis so callers get a
        # 1-D vector of Q-values. A blanket jnp.squeeze would also collapse a
        # size-1 feature axis and feed the wrong shape to the Dense layers.
        # Shapes are static under jit, so this Python branch is trace-safe.
        if x.shape[0] == 1:
            x = x[0]

        for width in self.features[3:]:
            x = nn.relu(nn.Dense(width, kernel_init=initializer)(x))

        return nn.Dense(self.n_actions, kernel_init=initializer)(x)


class QNetwork:
    """Wrapper around ``DQNNet`` that exposes a jitted greedy-action helper."""

    def __init__(self, features: Sequence[int], n_actions: int) -> None:
        """Build the underlying Q-value network.

        Args:
            features: Layer sizes forwarded unchanged to ``DQNNet``.
            n_actions: Number of discrete actions, i.e. the size of the
                Q-value output.

        Raises:
            ValueError: If ``n_actions`` is too large for action indices to
                fit in the ``int8`` values returned by ``best_action``.
        """
        # best_action casts indices to int8. Indices above 127 would silently
        # wrap to negative values, so reject such action spaces up front.
        max_actions = int(jnp.iinfo(jnp.int8).max) + 1
        if n_actions > max_actions:
            raise ValueError(
                f"n_actions={n_actions} exceeds {max_actions}, the largest action "
                "space whose indices fit in int8"
            )
        self.network = DQNNet(features, n_actions)

    @partial(jax.jit, static_argnames="self")
    def best_action(self, params, state: jnp.ndarray) -> jnp.int8:
        """Return the greedy action (argmax over Q-values) as ``int8``.

        ``self`` is a static jit argument. The compiled function is cached per
        ``QNetwork`` instance, and retraced for each new input shape.

        Args:
            params: Network parameters passed to ``self.network.apply``.
            state: One observation, or a batch of observations.

        Returns:
            A scalar action index for a single observation, or one index per
            observation for a batch.
        """
        q_values = self.network.apply(params, state)
        # Reduce over the action axis only. Without an axis, argmax flattens
        # a batch and returns an index into the flattened array.
        return jnp.argmax(q_values, axis=-1).astype(jnp.int8)