Spaces:
Sleeping
Sleeping
""" | |
Overview:
    Here is the behaviour cloning (BC) main entry for gfootball.
    We first collect demo data using ``FootballKaggle5thPlaceModel``, then train the BC model
    using the collected demo data,
    and (optionally) test the accuracy of the trained BC model on the training and test datasets.
""" | |
from copy import deepcopy | |
import os | |
from ding.entry import serial_pipeline_bc, collect_demo_data | |
from ding.config import read_config, compile_config | |
from ding.policy import create_policy | |
from dizoo.gfootball.entry.gfootball_bc_config import gfootball_bc_config, gfootball_bc_create_config | |
from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ | |
from dizoo.gfootball.model.bots.kaggle_5th_place_model import FootballKaggle5thPlaceModel | |
# Resolve the directory that contains this script; the demo-data file path
# below is built relative to it.
path = os.path.abspath(__file__)
dir_path = os.path.dirname(path)

# In the gfootball env, 3000 transitions = one episode, so
# 3e5 transitions = 100 episodes; the memory needs about 180G.
seed = 0
# Derive the experiment name from ``seed`` so the two cannot drift apart
# when the seed is changed (for seed = 0 this is 'gfootball_bc_kaggle5th_seed0').
gfootball_bc_config.exp_name = f'gfootball_bc_kaggle5th_seed{seed}'
demo_transitions = int(3e5)  # key hyper-parameter
# Use ``os.path.join`` instead of concatenating with a hard-coded '/' separator.
data_path_transitions = os.path.join(
    dir_path, f'gfootball_kaggle5th_{demo_transitions}-demo-transitions.pkl'
)
""" | |
phase 1: collect demo data utilizing ``FootballKaggle5thPlaceModel`` | |
""" | |
# Work on deep copies of the shared config objects so that the mutations made
# below (the policy type suffix) do not touch the imported originals.
train_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
# ``train_config`` is always a [main config, create config] pair built just above,
# never a config file path, so it is unpacked directly.
cfg, create_cfg = train_config
# NOTE(review): the '_command' suffix presumably selects the command-mode variant
# of the registered policy -- confirm against ding's policy registry.
create_cfg.policy.type += '_command'
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)

# The rule-based Kaggle 5th-place bot is passed as the policy's model and acts as the expert.
football_kaggle_5th_place_model = FootballKaggle5thPlaceModel()
expert_policy = create_policy(
    cfg.policy, model=football_kaggle_5th_place_model, enable_field=['learn', 'collect', 'eval', 'command']
)

# collect expert demo data
state_dict = expert_policy.collect_mode.state_dict()
# A fresh copy of the configs is used for collection, independent of ``train_config``.
collect_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
collect_demo_data(
    collect_config,
    seed=seed,
    expert_data_path=data_path_transitions,
    collect_count=demo_transitions,
    model=football_kaggle_5th_place_model,
    state_dict=state_dict,
)
""" | |
phase 2: BC training | |
""" | |
# Build the [main config, create config] pair for BC training from fresh deep copies,
# overriding the number of training epochs on the main config.
_bc_main_cfg = deepcopy(gfootball_bc_config)
_bc_main_cfg.policy.learn.train_epoch = 1000  # key hyper-parameter
bc_config = [_bc_main_cfg, deepcopy(gfootball_bc_create_config)]

# Train the ``FootballNaiveQ`` model on the demo data collected in phase 1.
football_naive_q = FootballNaiveQ()
_, converge_stop_flag = serial_pipeline_bc(
    bc_config,
    seed=seed,
    data_path=data_path_transitions,
    model=football_naive_q,
)