jmercat committed
Commit 5769ee4
0 Parent(s):

Removed history to avoid any unverified information being released

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +31 -0
  2. .gitignore +14 -0
  3. README.md +93 -0
  4. app.py +3 -0
  5. export_waymo_to_json.py +94 -0
  6. image/illustration.png +0 -0
  7. import_dataset_from_huggingface.py +55 -0
  8. import_model_from_huggingface.py +16 -0
  9. notebooks/visualize_planner_evaluation_results.ipynb +344 -0
  10. requirements.txt +16 -0
  11. risk_biased/__init__.py +0 -0
  12. risk_biased/config/learning_config.py +156 -0
  13. risk_biased/config/paths.py +9 -0
  14. risk_biased/config/planning_config.py +13 -0
  15. risk_biased/config/waymo_config.py +104 -0
  16. risk_biased/models/__init__.py +0 -0
  17. risk_biased/models/biased_cvae_model.py +907 -0
  18. risk_biased/models/context_gating.py +53 -0
  19. risk_biased/models/cvae_decoder.py +388 -0
  20. risk_biased/models/cvae_encoders.py +376 -0
  21. risk_biased/models/cvae_params.py +78 -0
  22. risk_biased/models/latent_distributions.py +468 -0
  23. risk_biased/models/map_encoder.py +38 -0
  24. risk_biased/models/mlp.py +60 -0
  25. risk_biased/models/multi_head_attention.py +81 -0
  26. risk_biased/models/nn_blocks.py +626 -0
  27. risk_biased/mpc_planner/__init__.py +0 -0
  28. risk_biased/mpc_planner/dynamics.py +49 -0
  29. risk_biased/mpc_planner/planner.py +332 -0
  30. risk_biased/mpc_planner/planner_cost.py +127 -0
  31. risk_biased/mpc_planner/solver.py +429 -0
  32. risk_biased/predictors/biased_predictor.py +568 -0
  33. risk_biased/scene_dataset/__init__.py +0 -0
  34. risk_biased/scene_dataset/loaders.py +252 -0
  35. risk_biased/scene_dataset/pedestrian.py +165 -0
  36. risk_biased/scene_dataset/scene.py +522 -0
  37. risk_biased/scene_dataset/scene_plotter.py +276 -0
  38. risk_biased/utils/__init__.py +0 -0
  39. risk_biased/utils/callbacks.py +595 -0
  40. risk_biased/utils/config_argparse.py +96 -0
  41. risk_biased/utils/cost.py +539 -0
  42. risk_biased/utils/load_model.py +220 -0
  43. risk_biased/utils/loss.py +124 -0
  44. risk_biased/utils/metrics.py +81 -0
  45. risk_biased/utils/planner_utils.py +462 -0
  46. risk_biased/utils/risk.py +571 -0
  47. risk_biased/utils/torch_utils.py +66 -0
  48. risk_biased/utils/waymo_dataloader.py +490 -0
  49. scripts/eval_scripts/compute_stats.py +523 -0
  50. scripts/eval_scripts/draw_cost.py +84 -0
.gitattributes ADDED
@@ -0,0 +1,31 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,14 @@
+ __pycache__/*
+ */__pycache__/*
+ **/__pycache__/*
+ scene_data_*.npy
+ scripts/logs/*
+ logs/*
+ wandb/*
+ data/*
+
+ *~*
+ *#*
+ *sweep_logs*
+ *.ipynb_checkpoints*
+ *.egg-info*
README.md ADDED
@@ -0,0 +1,93 @@
+ ---
+ title: "RAP: Risk-Aware Prediction"
+ emoji: 🚙
+ colorFrom: red
+ colorTo: grey
+ sdk: gradio
+ sdk_version: 3.7
+ app_file: app.py
+ pinned: false
+ language:
+ - Python
+ thumbnail: "url to a thumbnail used in social sharing"
+ tags:
+ - Risk Measures
+ - Forecasting
+ - Safety
+ - Human-Robot Interaction
+ license: cc-by-nc-4.0
+
+ ---
+
+ # License statement
+
+ The code is provided under an Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license. Under this license, the code is provided royalty-free for non-commercial purposes only. The code may be covered by patents; if you want to use the code for commercial purposes, please contact us for a different license.
+
+ # RAP: Risk-Aware Prediction
+
+ This is the official code for [RAP: Risk-Aware Prediction for Robust Planning](https://arxiv.org/abs/2210.01368). You can test the results in [our huggingface demo](https://huggingface.co/spaces/TRI-ML/risk_biased_prediction) and see some additional experiments on the [paper website](https://sites.google.com/view/corl-risk/).
+
+ ![A planner reacts to low-probability events if they are dangerous; biasing the predictions to better represent these events helps the planner to be cautious.](image/illustration.png)
+
+ We define and train a trajectory forecasting model and bias its predictions towards risk so that it helps a planner estimate risk by producing the relevant pessimistic trajectory forecasts to consider.
+
+ ## Datasets
+ This repository uses two datasets:
+ - A didactic simulated environment with a single vehicle at constant velocity and a single pedestrian.
+ Two pedestrian behaviors are implemented: fast and slow. At each step, pedestrians might walk at their favored speed or at the other speed.
+ This produces a distribution of pedestrian trajectories with two modes (see the sketch after this list). The dataset is automatically generated and used. You can change the parameters of the data generation in "config/learning_config.py".
+ - The Waymo Open Motion Dataset (WOMD) with complex real scenes.
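For intuition, the following is a minimal, hypothetical sketch of the two-mode pedestrian behavior described above; the speeds, switching probability, and time step are illustrative assumptions, not values taken from the repository's configuration.

```python
import numpy as np

def sample_pedestrian_trajectory(favored="slow", n_steps=30, dt=0.1, p_other=0.2, rng=None):
    """Roll out one 1D pedestrian trajectory with a favored speed and occasional switches."""
    speeds = {"slow": 1.0, "fast": 2.0}  # illustrative speeds in m/s
    other = "fast" if favored == "slow" else "slow"
    rng = np.random.default_rng() if rng is None else rng
    x = np.zeros(n_steps + 1)
    for t in range(n_steps):
        # At each step the pedestrian walks at its favored speed or, with
        # probability p_other, at the other one, so trajectories form two modes.
        speed = speeds[other] if rng.random() < p_other else speeds[favored]
        x[t + 1] = x[t] + speed * dt
    return x

# One mostly-slow and one mostly-fast trajectory, i.e. one draw near each mode.
trajectories = [sample_pedestrian_trajectory(favored=f) for f in ("slow", "fast")]
```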
+
+
+ ## Forecasting model
+ A conditional variational auto-encoder (CVAE) model is used as the base pedestrian trajectory predictor. Its latent space is either quantized or Gaussian, depending on the parameters you set in the config. It uses either multi-head attention or a modified version of context gating to account for interactions. Depending on the parameters, the trajectory encoder and decoder can be set to MLP, LSTM, or maskedLSTM.
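For orientation only, a CVAE trajectory predictor with a Gaussian latent space follows the pattern sketched below; the class, layer sizes, and dimensions are assumptions for illustration and do not correspond to the modules under risk_biased/models.

```python
import torch
import torch.nn as nn

class TinyTrajectoryCVAE(nn.Module):
    """Toy CVAE: encode (past, future) into a Gaussian latent, decode the future from (past, z)."""

    def __init__(self, past_dim=22, future_dim=30, latent_dim=8, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(past_dim + future_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 2 * latent_dim),  # mean and log-variance of q(z | past, future)
        )
        self.decoder = nn.Sequential(
            nn.Linear(past_dim + latent_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, future_dim),
        )

    def forward(self, past, future):
        mu, logvar = self.encoder(torch.cat([past, future], dim=-1)).chunk(2, dim=-1)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterization trick
        return self.decoder(torch.cat([past, z], dim=-1)), mu, logvar
```

At prediction time, latent samples come from a prior over z rather than from the posterior encoder; the risk-biased encoder trained in the second phase (see Configuration and training below) is what shifts that latent distribution towards high-cost outcomes.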
+
+ # Usage
+
+ ## Installation
+
+ - (Set up a virtual environment with Python > 3.7)
+ - Install the package with `pip install -e .`
+
+ ## Setting up the data
+
+ ### Didactic simulation
+ - The dataset is automatically generated and used. You can change the parameters of the data generation in "config/learning_config.py".
+
+ ### WOMD
+ - [Download the Waymo Open Motion Dataset (WOMD)](https://waymo.com/open/)
+ - Pre-process it as follows:
+ - Sample set: `python scripts/scripts_utils/generate_dataset_waymo.py <data/Waymo>/scenario/validation <data/Waymo>/interactive_veh_type/sample --num_parallel=<16> --debug_size=<1000>`
+ - Training set: `python scripts/interaction_utils/generate_dataset_waymo.py <data/Waymo>/scenario/training <data/Waymo>/interactive_veh_type/training --num_parallel=<16>`
+ - Validation set: `python scripts/interaction_utils/generate_dataset_waymo.py <data/Waymo>/scenario/validation_interactive <data/Waymo>/interactive_veh_type/validation --num_parallel=<16>`
+
+ Replace the arguments:
+ - `<data/Waymo>` with the path where you downloaded WOMD
+ - `<16>` with the number of cores you want to use
+ - `<1000>` with the number of scenes to process for the sample set (some scenes are filtered out, so the resulting number of pre-processed scenes might be about a third of the input number)
+ - Set up the path to the dataset in "risk_biased/config/paths.py"
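As an illustrative substitution of the placeholders above, with WOMD downloaded to `data/Waymo` and 16 cores, the training-set command would read `python scripts/interaction_utils/generate_dataset_waymo.py data/Waymo/scenario/training data/Waymo/interactive_veh_type/training --num_parallel=16`.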
+
+ ## Configuration and training
+
+ - Set up the output log path in "risk_biased/config/paths.py"
+ - You might need to log in to wandb with `wandb login <option> <key>...`
+ - All the parameters defined in "risk_biased/config/learning_config.py" or "risk_biased/config/waymo_config.py" can be overwritten with a command line argument.
+ - To start from a WandB checkpoint, use the option `--load_from "<wandb id>"`. If you wish to force the usage of the local configuration instead of the checkpoint configuration, add the option `--force_config`. If you want to load the last checkpoint instead of the best one, add the option `--load_last`.
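As an illustration of the two points above, a run that resumes from a checkpoint while forcing the local configuration and overriding one value might look like `python scripts/train_interaction.py --load_from "<wandb id>" --force_config --batch_size 32`; here `--batch_size` stands in for any parameter name defined in the config files and is an assumed example, not a verified flag.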
+
+ ### Didactic simulation
+ - Choose the parameters to set in "risk_biased/config/learning_config.py"
+ - Start training: `python scripts/train_didactic.py`
+
+ ### WOMD
+ - Choose the parameters to set in "risk_biased/config/waymo_config.py"
+ - Start training: `python scripts/train_interaction.py`
+
+ Training has two phases: training the unbiased predictor, then training the biased encoder with a frozen predictor model. This second step needs to draw many samples to estimate the risk, so it is possible that your GPU runs out of memory at this stage. If it does, consider reducing the batch size and the number of samples "n_mc_samples_biased". If the number of samples "n_mc_samples_risk" is kept high, the risk estimation will be more accurate but training might be very slow.
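To see why these sample counts drive both memory use and accuracy, the snippet below is a small, self-contained sketch of Monte Carlo risk estimation over sampled prediction costs, using CVaR as an example risk measure; the cost tensor and sample counts are placeholders, and this is not the repository's implementation.

```python
import torch

def monte_carlo_cvar(cost_samples: torch.Tensor, risk_level: float = 0.95) -> torch.Tensor:
    """Estimate CVaR: the mean of the worst (1 - risk_level) fraction of sampled costs."""
    n = cost_samples.shape[-1]
    k = max(1, int(round((1.0 - risk_level) * n)))  # number of tail samples kept
    tail = torch.topk(cost_samples, k, dim=-1).values
    return tail.mean(dim=-1)

# More samples give a steadier tail estimate, but each sample is one forward pass
# through the frozen predictor, so memory and time grow with the sample count.
costs = torch.rand(8, 4096)            # (batch, n_mc_samples_risk) placeholder costs
risk = monte_carlo_cvar(costs, 0.95)   # one risk estimate per batch element
```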
+
+ ## Evaluation
+
+ Many evaluation scripts are available in "scripts/eval_scripts" to compute results, plot graphs, draw the didactic experiment scene, etc.
+
+ You can also run the interactive interface locally with `python scripts/scripts_utils/plotly_interface.py --load_from=<full path to the .ckpt checkpoint file> --cfg_path=<full path to the learning_config.py file from the checkpoint>`
+ Sadly, the WOMD license does not allow us to provide the pre-trained weights of our model, so you will need to train it yourself.
+
app.py ADDED
@@ -0,0 +1,3 @@
+ from scripts.scripts_utils.plotly_interface import main
+
+ main()
export_waymo_to_json.py ADDED
@@ -0,0 +1,94 @@
+ import json
+ from json import JSONEncoder
+ from mmcv import Config
+ import numpy
+ import torch
+
+ from risk_biased.utils.waymo_dataloader import WaymoDataloaders
+
+
+ class NumpyArrayEncoder(JSONEncoder):
+     def default(self, obj):
+         if isinstance(obj, numpy.ndarray):
+             return obj.tolist()
+         return JSONEncoder.default(self, obj)
+
+ if __name__ == "__main__":
+     output_path = "../risk_biased_dataset/data.json"
+     config_path = "risk_biased/config/waymo_config.py"
+     cfg = Config.fromfile(config_path)
+     dataloaders = WaymoDataloaders(cfg)
+     sample_dataloader = dataloaders.sample_dataloader()
+     (
+         x,
+         mask_x,
+         y,
+         mask_y,
+         mask_loss,
+         map_data,
+         mask_map,
+         offset,
+         x_ego,
+         y_ego,
+     ) = sample_dataloader.collate_fn(sample_dataloader.dataset)
+
+     batch_size, n_agents, n_timesteps_past, n_features = x.shape
+     n_timesteps_future = y.shape[2]
+     n_features_map = map_data.shape[3]
+     n_features_offset = offset.shape[2]
+
+     print(x.shape)
+     print(mask_x.shape)
+     print(y.shape)
+     print(mask_y.shape)
+     print(mask_loss.shape)
+     print(map_data.shape)
+     print(mask_map.shape)
+     print(offset.shape)
+     print(x_ego.shape)
+     print(y_ego.shape)
+
+
+     data = {"x": x.numpy(),
+             "mask_x": mask_x.numpy(),
+             "y": y.numpy(),
+             "mask_y": mask_y.numpy(),
+             "mask_loss": mask_loss.numpy(),
+             "map_data": map_data.numpy(),
+             "mask_map": mask_map.numpy(),
+             "offset": offset.numpy(),
+             "x_ego": x_ego.numpy(),
+             "y_ego": y_ego.numpy(),
+             }
+
+     json_data = json.dumps(data, cls=NumpyArrayEncoder)
+
+     with open(output_path, "w+") as f:
+         f.write(json_data)
+
+     with open(output_path, "r") as f:
+         decoded = json.load(f)
+
+     x_c = torch.from_numpy(numpy.array(decoded["x"]).astype(numpy.float32))
+     mask_x_c = torch.from_numpy(numpy.array(decoded["mask_x"]).astype(numpy.bool8))
+     y_c = torch.from_numpy(numpy.array(decoded["y"]).astype(numpy.float32))
+     mask_y_c = torch.from_numpy(numpy.array(decoded["mask_y"]).astype(numpy.bool8))
+     mask_loss_c = torch.from_numpy(numpy.array(decoded["mask_loss"]).astype(numpy.bool8))
+     map_data_c = torch.from_numpy(numpy.array(decoded["map_data"]).astype(numpy.float32))
+     mask_map_c = torch.from_numpy(numpy.array(decoded["mask_map"]).astype(numpy.bool8))
+     offset_c = torch.from_numpy(numpy.array(decoded["offset"]).astype(numpy.float32))
+     x_ego_c = torch.from_numpy(numpy.array(decoded["x_ego"]).astype(numpy.float32))
+     y_ego_c = torch.from_numpy(numpy.array(decoded["y_ego"]).astype(numpy.float32))
+
+     assert torch.allclose(x, x_c)
+     assert torch.allclose(mask_x, mask_x_c)
+     assert torch.allclose(y, y_c)
+     assert torch.allclose(mask_y, mask_y_c)
+     assert torch.allclose(mask_loss, mask_loss_c)
+     assert torch.allclose(map_data, map_data_c)
+     assert torch.allclose(mask_map, mask_map_c)
+     assert torch.allclose(offset, offset_c)
+     assert torch.allclose(x_ego, x_ego_c)
+     assert torch.allclose(y_ego, y_ego_c)
+
+     print("All good!")
image/illustration.png ADDED
import_dataset_from_huggingface.py ADDED
@@ -0,0 +1,55 @@
+ from datasets import load_dataset
+ import datasets
+ import json
+ from mmcv import Config
+ import numpy
+ import torch
+
+ from risk_biased.utils.waymo_dataloader import WaymoDataloaders
+
+
+ config_path = "risk_biased/config/waymo_config.py"
+ cfg = Config.fromfile(config_path)
+ dataloaders = WaymoDataloaders(cfg)
+ sample_dataloader = dataloaders.sample_dataloader()
+ (
+     x,
+     mask_x,
+     y,
+     mask_y,
+     mask_loss,
+     map_data,
+     mask_map,
+     offset,
+     x_ego,
+     y_ego,
+ ) = sample_dataloader.collate_fn(sample_dataloader.dataset)
+
+ # dataset = load_dataset("json", data_files="../risk_biased_dataset/data.json", split="test", field="x")
+ # dataset = load_from_disk("../risk_biased_dataset/data.json")
+ dataset = load_dataset("jmercat/risk_biased_dataset", split="test")
+
+ x_c = torch.from_numpy(numpy.array(dataset["x"]).astype(numpy.float32))
+ mask_x_c = torch.from_numpy(numpy.array(dataset["mask_x"]).astype(numpy.bool8))
+ y_c = torch.from_numpy(numpy.array(dataset["y"]).astype(numpy.float32))
+ mask_y_c = torch.from_numpy(numpy.array(dataset["mask_y"]).astype(numpy.bool8))
+ mask_loss_c = torch.from_numpy(numpy.array(dataset["mask_loss"]).astype(numpy.bool8))
+ map_data_c = torch.from_numpy(numpy.array(dataset["map_data"]).astype(numpy.float32))
+ mask_map_c = torch.from_numpy(numpy.array(dataset["mask_map"]).astype(numpy.bool8))
+ offset_c = torch.from_numpy(numpy.array(dataset["offset"]).astype(numpy.float32))
+ x_ego_c = torch.from_numpy(numpy.array(dataset["x_ego"]).astype(numpy.float32))
+ y_ego_c = torch.from_numpy(numpy.array(dataset["y_ego"]).astype(numpy.float32))
+
+ assert torch.allclose(x, x_c)
+ assert torch.allclose(mask_x, mask_x_c)
+ assert torch.allclose(y, y_c)
+ assert torch.allclose(mask_y, mask_y_c)
+ assert torch.allclose(mask_loss, mask_loss_c)
+ assert torch.allclose(map_data, map_data_c)
+ assert torch.allclose(mask_map, mask_map_c)
+ assert torch.allclose(offset, offset_c)
+ assert torch.allclose(x_ego, x_ego_c)
+ assert torch.allclose(y_ego, y_ego_c)
+
+ print("All good!")
+
import_model_from_huggingface.py ADDED
@@ -0,0 +1,16 @@
+ from huggingface_hub import hf_hub_url, cached_download
+ from mmcv import Config
+ import torch
+
+ from risk_biased.utils.load_model import get_predictor
+ from risk_biased.utils.torch_utils import load_weights
+ from risk_biased.utils.waymo_dataloader import WaymoDataloaders
+
+
+ config_file = cached_download(hf_hub_url("jmercat/risk_biased_model", filename="learning_config.py"), force_filename="learing_config.py")
+ ckpt = torch.load(cached_download(hf_hub_url("jmercat/risk_biased_model", filename="last.ckpt"), force_filename="last.ckpt"), map_location="cpu")
+ cfg = Config.fromfile(config_file)
+ predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
+ predictor = load_weights(predictor, ckpt)
+
+ print("Model loaded")
notebooks/visualize_planner_evaluation_results.ipynb ADDED
@@ -0,0 +1,344 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "a0d39c3c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%matplotlib notebook\n",
+ "# Switch to inline if debugging plotting.\n",
+ "# %matplotlib inline\n",
+ "\n",
+ "import os\n",
+ "import pickle\n",
+ "import sys\n",
+ "\n",
+ "from matplotlib import animation\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "edc5a0af",
+ "metadata": {},
+ "source": [
+ "### Load evaluation results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "569e209d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Replace with the directory containing evaluation outputs.\n",
+ "stats_dir = \"../scripts/eval_scripts/logs/planner_eval/run-1rdonjl7_0/\"\n",
+ "scene_type = \"safer_slow\"\n",
+ "risk_level = 1.0\n",
+ "num_samples = 256\n",
+ "\n",
+ "filepath = os.path.join(stats_dir, f\"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_predictor.pkl\")\n",
+ "with open(filepath, \"rb\") as infile:\n",
+ "    predictor_data = pickle.load(infile)\n",
+ "filepath = os.path.join(stats_dir, f\"{scene_type}_{num_samples}_samples_risk_level_{risk_level}_in_planner.pkl\")\n",
+ "with open(filepath, \"rb\") as infile:\n",
+ "    planner_data = pickle.load(infile)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6cdbb58a",
+ "metadata": {},
+ "source": [
+ "### Find the most relevant episodes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "49b8859c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def print_riskiest_episodes(data, max_ep_id=10000, num_to_print=10, key=\"interaction_risk\"):\n",
+ "    episode_id_risk = []\n",
+ "    for episode_id in range(max_ep_id):\n",
+ "        if episode_id not in data:\n",
+ "            break\n",
+ "        episode_id_risk.append((episode_id, data[episode_id][key]))\n",
+ "    episode_id_risk = sorted(episode_id_risk, key=lambda x: x[1], reverse=True)\n",
+ " \n",
+ "    print(\"Riskiest episodes:\")\n",
+ "    for ep_id, risk in episode_id_risk[:num_to_print]:\n",
+ "        print(f\"episode id: {ep_id}\\trisk: {risk:0.4f}\")\n",
+ "\n",
+ "def print_largest_risk_difference_episodes(\n",
+ "    data_a, \n",
+ "    data_b, \n",
+ "    max_ep_id=10000, \n",
+ "    key=\"interaction_risk\", \n",
+ "    num_to_print=10):\n",
+ "    \"\"\"Plots the episodes where the risk of a is most larger than that of b.\"\"\"\n",
+ "    risk_a = []\n",
+ "    risk_b = []\n",
+ "    for episode_id in range(max_ep_id):\n",
+ "        if episode_id not in data_a or episode_id not in data_b:\n",
+ "            break\n",
+ "        risk_a.append(data_a[episode_id][key])\n",
+ "        risk_b.append(data_b[episode_id][key])\n",
+ "    risk_a = np.array(risk_a)\n",
+ "    risk_b = np.array(risk_b)\n",
+ " \n",
+ "    diff = risk_a - risk_b\n",
+ "    indices = np.argsort(diff)[::-1]\n",
+ " \n",
+ "    print(\"Episdoes where the first data is risker than the second\")\n",
+ "    for episode_id in indices[:num_to_print]:\n",
+ "        print(f\"episode_id: {episode_id}\\tfirst data risk: {risk_a[episode_id]:0.4f}\\tsecond data risk: {risk_b[episode_id]:0.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "1b7de380",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Riskiest episodes:\n",
+ "episode id: 66\trisk: 1.1713\n",
+ "episode id: 48\trisk: 1.0506\n",
+ "episode id: 32\trisk: 1.0245\n",
+ "episode id: 37\trisk: 1.0139\n",
+ "episode id: 39\trisk: 0.9978\n",
+ "episode id: 84\trisk: 0.9266\n",
+ "episode id: 24\trisk: 0.9190\n",
+ "episode id: 79\trisk: 0.8993\n",
+ "episode id: 75\trisk: 0.8989\n",
+ "episode id: 71\trisk: 0.8468\n"
+ ]
+ }
+ ],
+ "source": [
+ "print_riskiest_episodes(predictor_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "ba9bc377",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Episdoes where the first data is risker than the second\n",
+ "episode_id: 54\tfirst data risk: 1.2152\tsecond data risk: 1.0733\n",
+ "episode_id: 82\tfirst data risk: 0.5233\tsecond data risk: 0.3918\n",
+ "episode_id: 30\tfirst data risk: 1.2241\tsecond data risk: 1.0988\n",
+ "episode_id: 29\tfirst data risk: 0.8171\tsecond data risk: 0.7078\n",
+ "episode_id: 48\tfirst data risk: 1.0183\tsecond data risk: 0.9133\n",
+ "episode_id: 18\tfirst data risk: 0.4989\tsecond data risk: 0.4065\n",
+ "episode_id: 95\tfirst data risk: 0.5201\tsecond data risk: 0.4409\n",
+ "episode_id: 4\tfirst data risk: 0.8029\tsecond data risk: 0.7406\n",
+ "episode_id: 23\tfirst data risk: 0.9542\tsecond data risk: 0.8920\n",
+ "episode_id: 72\tfirst data risk: 1.0446\tsecond data risk: 0.9828\n"
+ ]
+ }
+ ],
+ "source": [
+ "print_largest_risk_difference_episodes(predictor_data, planner_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5fc3cfc3",
+ "metadata": {},
+ "source": [
+ "### Animate those episodes for the risk-sensitive predictor vs risk-sensitive planner."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "c03ae0f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def animate_episode(solver_infos, ado_positions, ado_predictions):\n",
+ "    fig = plt.figure()\n",
+ "    ax = plt.axes(xlim=(0, 100), ylim=(-4, 8))\n",
+ " \n",
+ "    scatter = ax.scatter([], [])\n",
+ "    text = ax.annotate(\"\", (2,7.5))\n",
+ " \n",
+ "    def animate(t):\n",
+ "        solver_iter = t // 45\n",
+ "        timestep = t % 45\n",
+ "        ado_position = np.array([[ado_positions[0, timestep, 0], ado_positions[0, timestep, 1]]]) \n",
+ "        solver_info = solver_infos[solver_iter]\n",
+ "        biased_predicted_ado_positions = solver_info[\"ado_state_future_samples\"][:, 0, timestep, :2].reshape([-1, 2])\n",
+ "        predicted_ado_positions = ado_predictions[:num_samples, 0 , timestep, :2].reshape([-1, 2])\n",
+ "        len_pred = len(predicted_ado_positions)\n",
+ "        len_biased_pred = len(biased_predicted_ado_positions)\n",
+ "        ego_positions = solver_info[\"ego_state_future\"][:, 0, timestep, :2]\n",
+ "        positions = np.concatenate((ado_position, predicted_ado_positions, biased_predicted_ado_positions, ego_positions)) \n",
+ "        scatter.set_offsets(positions)\n",
+ "        scatter.set_alpha(0.5)\n",
+ "        colors = np.ones(len(positions))*0.4\n",
+ "        colors[0] = 0\n",
+ "        colors[1:len_pred+1] = 0.6\n",
+ "        colors[len_pred+1:len_biased_pred+len_pred+1] = 0.8\n",
+ "        scatter.set_array(colors)\n",
+ "        total_risk = solver_info[\"total_risk\"].mean()\n",
+ "        tracking_cost = solver_info[\"tracking_cost\"].mean()\n",
+ "        text_str = \"solver_iter: {}, timestep: {:02d}, total risk: {:0.2f}, tracking cost: {:0.2f}\".format(\n",
+ "            solver_iter,\n",
+ "            timestep, \n",
+ "            total_risk,\n",
+ "            tracking_cost,\n",
+ "        )\n",
+ "        text.set_text(text_str)\n",
+ "        return scatter, text\n",
+ " \n",
+ "    num_frames = 45 * 10\n",
+ "    anim = animation.FuncAnimation(fig, animate, frames=num_frames, interval=20, blit=True, save_count=sys.maxsize)\n",
+ " \n",
+ "    return anim"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "519e9eee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "episode_id = 66"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "dcc04f38",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/javascript": "<matplotlib nbagg figure-widget JavaScript boilerplate omitted from this view>",
+ "text/plain": [
+ "<IPython.core.display.Javascript object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "<img src=\"data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAYAAAA10dzkAAAAAXNSR0IArs4c6QAAIABJREFUeF7t3QmcVXXdx/HfzLAjAwqiLCOiuJuiIKiZSuVa+WimuaSoZaWppJWiWUG5Zk+a+4KpZeaWS6mYlolriEYKiQsaiiCICAMoDsvM8/qeOvNcrnfmnju/mfO/M+dzXi9fPY/e/z3nvs/3nt/v/s8yFQ0NDQ3GggACCCCAAAIIIJAZgQoawMzsaz4oAggggAACCCAQCdAAEgQEEEAAAQQQQCBjAjSAGdvhfFwEEEAAAQQQQIAGkAwggAACCCCAAAIZE6ABzNgO5+MigAACCCCAAAI0gGQAAQQQQAABBBDImAANYMZ2OB8XAQQQQAABBBCgASQDCCCAAAIIIIBAxgRoADO2w/m4CCCAAAIIIIAADSAZQAABBBBAAAEEMiZAA5ixHc7HRQABBBBAAAEEaADJAAIIIIAAAgggkDEBGsCM7XA+LgIIIIAAAgggQANIBhBAAAEEEEAAgYwJ0ABmbIfzcRFAAAEEEEAAARpAMoAAAggggAACCGRMgAYwYzucj4sAAggggAACCNAAkgEEEEAAAQQQQCBjAjSAGdvhfFwEEEAAAQQQQIAGkAwggAACCCCAAAIZE6ABzNgO5+MigAACCCCAAAI0gGQAAQQQQAABBBDImAANYMZ2OB8XAQQQQAABBBCgASQDCCCAAAIIIIBAxgRoADO2w/m4CCCAAAIIIIAADSAZQAABBBBAAAEEMiZAA5ixHc7HRQABBBBAAAEEaADJAAIIIIAAAgggkDEBGsCM7XA+LgIIIIAAAgggQANIBhBAAAEEEEAAgYwJ0ABmbIfzcRFAAAEEEEAAARpAMoAAAggggAACCGRMgAYwYzucj4sAAggggAACCNAAkgEEEEAAAQQQQCBjAjSAGdvhfFwEEEAAAQQQQIAGkAwggAACCCCAAAIZE6ABzNgO5+MigAACCCCAAAI0gGQAAQQQQAABBBDImAANYMZ2OB8XAQQQQAABBBCgASQDCCCAAAIIIIBAxgRoADO2w/m4CCCAAAIIIIAADSAZQAABBBBAAAEEMiZAA5ixHc7HRQABBBBAAAEEaADJAAIIIIAAAgggkDEBGsCM7XA+LgIIIIAAAgggQANIBhBAAAEEEEAAgYwJ0ABmbIfzcRFAAAEEEEAAARpAMoAAAggggAACCGRMgAYwYzucj4sAAggggAACCNAAkgEEEEAAAQQQQCBjAjSAGdvhfFwEEEAAAQQQQIAGkAwggAACCCCAAAIZE8hsA7h27VqbMGGC3XrrrbZgwQIbOHCgHXfccXbuuedaRUVFxmLAx0UAAQQQQACBLAlktgG84IIL7Je//KXdcssttt1229nzzz9vxx9/vJ1//vl22mmnZSkDfFYEEEAAAQQQyJhAZhvAL37xi7bRRhvZjTfe2LjLDz30UOvevXs0K8iCAAIIIIAAAgh0VIHMNoCaAbz++uvtkUcesS233NJefPFF23fffaNZwaOPPrqj7m8+FwIIIIAAAgggYJltAOvr6+2cc86xn//851ZVVWW6JlCnf88+++wmY1FXV2f6J170Hh988IH17duX6wb5MiGAAAIIINBOBBoaGmz58uXR9f+VlZXtZKtbdzMz2wDefvvt9oMf/MAuueSS6BrAf/7zn/bd7343mgEcO3ZsQWXdNDJx4sTW3QO8GwIIIIAAAggEEZg7d64NHjw4yLpDrzSzDWBNTY2NHz/evvOd7zTug/POOy+6/u+VV14puF/yZwBra2ttk002MQWouro69L5k/QgggAACCCCQQGDZsmWmPmDp0qXWu3fvBCM63ksy2wDqtK0avpNOOqlxr1544YV200032WuvvZZoTytACo4aQRrARGS8CAEEEEAAgeAC1G/L7jWAeubfX/7yF7vuuuuiU8DTp0+3b37zm3bCCSfYxRdfnCicBCgREy9CAAEEEECgrASo3xluAHXx549+9CO799577b333osuBD3yyCPtxz/+sXXp0iVRUNMM0Jw5c2zo0KFRozp8+PBE29cWL9p7772j9V922WVt8fZB3/Pmm2+OrgPVKQGWjiXQ0u9PKeNKeW3H0m27T7PppptG30n9U2jpyMejtlPlnSWQZv0uV/HMngJujR2SZoDKpbjorufOnTtbr169IsJiB+jWcM59j8cff9wuvfRSe+6556Iv8BZbbBHdzFPqo3sKbffKlSuju8L69+/f2pud+P1CN6FPPfWUnXXWWdF1sB999JENGTLEvvWtb9npp5++zme46qqrohuo9Fd0dtxxR7viiits1KhRiT+n9uOYMWNsyZIl1qdPnzYf19Lvj54OsGjRIuvXr5916tSp2e1s6Try3/Tjjz+2733ve6Yb1XTd8X777WdXX3119NzSppaFCxdG+02PtdIPmD333DPaJ/p+xIv2lb4rjz76aJTzrbbayn74wx+ann+adEk7n8WOL/nHo6Sfo5xep7NR2mf33XdfyZulz3/qqafan/70p+hOVu3LX/3qV7beeus1+V56/Nltt91m//jHP6Ic5H8HleOf/exn9thjjzX+layvfe1rUVaSTo6U/EECDEizfgf4eIlWSQOYiKnwi9IMUGsVl+Y+7qpVq0r+ghc7QCflTbpuPb9RjdoBBxwQFcQHHnjAzjjjDLv//vtND/dOurTWdiddX9LXpV1g87dLM8xq/nbYYQfr2bOnqSFUA6imW5dIaLnjjjvs2GOPtWuvvdZGjx4dzQbfdddd9uqrryZunttDA5g0k7Fha31HdV3ygw8+aMqCrjE+5ZRTouL+9NNPF4yRHmex++67Rz/M/vd//ze6HllPM3j44Yft5ZdfjvajFj3nVI3GlVdeGTW0agJ+8pOfRH8FaaeddkoU0aT5LNWuqZWX6/c0EVbCF3kaQB0H33333ehSptWrV0d/zWqXXXaJ9m1Ti76v+pGhRY89y28AlRt9x3VGbNiwYTZz5kw78cQT7ZhjjrFf/OIXCT9V+b8szfpdrho0gI49U2qA7r777ugxMrNnz7YePXpEB101LjpA65mCuilFv84047DNNtvYRRddZPvvv3+0hbnFRcVZdx/rF1nuTSwq3iNGjLB///vf0cyNDvbf//73o3VoJmHkyJFRIdeMjRY91ka/OlVg9AzEt956K9qO5pbcUy76v6dMmbLOy1WMtKhx0MFFxUXF5pBDDjHdZBMXIx3Yv/71r9vrr78ebcOXv/zlqOC1ZPnCF74QNYO//vWvEw1varvzi1vsoz8NqP9bv7bV+GhmRYVWRVZe48aNi/ZFvBRz10PHdUpLNvq705ql0QF8xYoV0axY7qICrXVr/2kdv//976P9uv3220fXquqzaIm3Xf+rWR7dmb7XXnvZpEmTojvdPIv2jfbbb3/72+ht1PSpyKiR0CIDrUMzEbqzvtgSZzn3dXr0krZdn1Pbr9kvfb/izGp9zY1T0dL3R8VKz/XcbbfdopmQzTff/BPfn6YuoSi
USdnnXnqhYqnvi2batL/0+Ag9T1SFN78B1OyhCuczzzwTvV7f2WKLbijbcMMNowL+la98JXq5GnIdD5599lnbddddP/EWumlNs3n67LqeOd4nG2+8sekH0ze+8Y3o32lW6JprrokKebzoZjjlKH5Nc9sXN+2F8tnU91mzkrrM5p133jFtj2bqdZmNmtV40ezVT3/6U5sxY0a0jZ/5zGeiMVryG0DlWce0P/zhD/a5z30uyn/uJSl6vX6o6BirHyXrr79+9Pfd4x8vek/tj5NPPjly1fdI/13Hp+Yur1Eutd3aL7pkSHnX8U3HMC06Diq3+m5vsMEG0aPElMd41ripY79m0fMfLfa3v/2t8Xvd3P6YNWuWbbvttjZt2rToe6JF34MDDzww8tZlTUn2Z5JZeG2nsvPmm28Wi3C7+e+l1u9288FK2FAawBKw8l9aSoD0K00FQA+e1sFGU+9PPvlk1FDooKfGTMVGjYAaQzUz+nf/+te/ogYhv7joYPP3v/89eo940YFx6tSpjf9un332if60nQ5cmknQe6vIqmDoIKX16RedDrgqFCqcai6bW3IPuGqI1Ezq4KpCp0UH+TfeeCP69zoAqjlTQ6uiqX+nu6zjA7sOPNq2gw8+OPp3KtZqiPQa/SpOuuyxxx5RYUz667Sp7S7UAKrRUxOug7Q+l4qymjT99Rg9QkjFRDcOaV+oMdJSzF1FR/tYDZ3M9QxKvZ+KvA6yMtFsmhZlQ//IV7M5+lGgA7sKpAqXiqbyoW3XfpDx5ZdfHs3kqsipAMUzR3GGkhYYrV9FUbMM2pdqEjSzox8vKmjxftPrVPDUmOrHRrFFjZFep9NV+pyasVJOlVE103pvFXr9iNH35Y9//GNU0PXfmxqnhkDZUX7VmMlQn1e2mj1LMjun5iE/k9o/uQ2gcizPG264Ifpho+3SjPSXvvSlddahfakZFK33z3/+c9TUaVGu9e/UTBVadNpNjU1+UZaFfjTkn4rXeygD+tzalrjh1b9Xk6L3in9YaQZQufjNb34TnXa/8847owZGTYtmeoot2vdN5bOQnbZFufnsZz8bZVbbqRxrxv7MM8+MVqeZzv/5n/+JvgtHHHFElK+HHnqo8YH8uQ2gsqB/9N/jyw0KNYA6tuoUpj6vsqT31ndHTbKO2dqfapLUwOlHr1x1TGyuAfzqV78aNeD6UaHvmH5kv//++6Z/P2/evOj7q32rH0FqLPU5dXzQMba5Y78MtA+0XfGxUcdm7Sd9Nn3+pn4Yq0boUgFlJV7WrFlj3bp1i5pf1ZnmllJm4XWsUXOpH60dZSmlfneUz5z/OWgAHXu2lADpegvNzungr4N5/jJo0KDogKHZhHjRQU4zH7reKr+AqbDtvPPO0b9XY6lZGP2vvqjf/va3oxk4NV/6tdq1a9fG99SBXgdfNQs6OKnx0wEsLlDFOAodcPMv0lajoMKphjNetD2akfrwww+jA5QObGqC4l/68eu23nrraKaw2MErfr2KmGY05BvPfhT7DPrvhU4tFWoA4+vc4mse1QyqaVEzGD89Xtusg79mv5K4q+HRLGKhB44XOsX29ttv22abbWb639xf9Z///OejQqh9qHGahcptROOZI/0o0Ou0n9UQqAEodr2eZrbUuKugKCe6YUrL/PnzTVlV46tZtnhRpjQLonUlWQoVH2VDMzb6LEcddVT0NjqtFe8r/ehJWrRUnJVpNR1quJM2gPmZzB930EEHRY1fodnm+LX6URbP2uoShdxnjKnp0H7QPii0aIZJ+zH3Lw7pddpf+uFR6AkFMtL3Wj9A9J3TbK1+PCqPaoLUgGpRg66GRbOR+mGgRl6Ngl6TdGnqFHBT3+f899WPNM3uxo2ETl0r2039/fV436uJ0gy0rl/M/Z4XOh7pB208W60zEvpRqh9wOi7qsgUdIzVDpuOQFv3YUMPWVAMYz7Bq3frO5S9qMPUDRDNy+hGiRddsavZTM7o6Vjd37G/qFLAmB/Rd0/Gw0KLv/S233NL4YzF+ja5h1ufNPTtUaHzS75J+WGj7te/iH/pJ81LOryulfpfz5/BsGw2gQ6+UAGnWQxdz6+YF/a8OuppNUsGL30dfSDVJ8aJf+/p1rlmBQgVMB0I1PzrQa1ZH76sDpU7rqGnUqUvNrOQumq3QTKEKiYrU7373u+g0bNIlSQOopvWll15a5zSPDsS6qUC/xDU7ogO7Dia5p06TbkP8On1mXfenWQkdLEtZkjaAKpCahY0XNW1qLjRzES/aZzpA6pRwUnedctc4FZTDDjusceamUIHVuvQ549Pn8XrVJOj0rK7X0Th56t/l/lkj5UvX/DT1122aMtMMh2bS1FAqXzrdqxmttmwAlRnNruT/SNKPAX0ONV1NFS1lWLN+akC1f/SDSA2l7DTbk7QBzM9k/rjJkydHM5ea8dF3WLOgamK0xK9V86x/9L3N//4Vy2hLGkC95wsvvNA4m6cfX8qVcqDvnbZZi2andPxR46AmVpdeqFFUw/qpT32q2KZF/725BrDQ91nZ1Iy0fjApT/pBoR9A+mGqRU2ovjNqegst+p7q2Kl9qaZRzWLuUuh4pB/S+rEQL8qU9pnykXtMjf97nLumGkD9yNQPEh07c09dx+P1HVSTH8/g6d/ruK1T05phVBPX1LFfr23pNYBpNID6saLjlJzVKHekpZT63ZE+d+5noQF07NlSA6SDcXw9kGa+dFeeCpYaNh1ASm0AdQDQr2kdwHTw1Z2AOl2mRQ2eZpkKnWrS6R8VgPgaN/1CTbokaQDV4Ok0qBrQ/EWzlDq94b24W7NNmuFU05V7fU/Sz5G0AVSRzPUpdLDONUnirm3UrIKaExVnfRbtRzU6hQqsiqiunVIjquKeu+j0sGY4WrsBzF2HTuNpRkUzn61xCljvXaiR8zSAmoXVzLpmIjVLqgZQM3/6nqlJS9oA5s9mFxqnmVGdhtSMkGZ+1HBodiR+rfKoGS1d26bTn6UsLTkFnPv+mnHSPtLsp2YEdW2YGiw1YPEF/bkzaGoU9e81M5Zkaa4BzLfTKVPNxmk2Sg2QjnHKuS6riB+1pGOf7JprANVg67ui2dP8a0yTHI/UiCkDOt61pAHUftR3s6UNoI55TR37dTq6pQ1gW58C1o89+eryGu33jvb3ckut30m+H+3tNTSAjj3mCZB+1apg6XoY/dPUKWCd+tHsS6FCpH+nX8S6CFgHch3EdYpHi4qTrt3S9L2anUJLazSAmgnRXaK6FiVe1KyoGdWDtptaPA2gmgfNiKnZyv1TfqXsykLb3dRNIKU0gEnc87dTM2ua4VDzrhkgeeo6pniJT0E98cQTUUEttMSngOPTvXqNGjY1Rrn/rhSj+LW6QF/FRnnTosZCudQPDC3x5Qe6Pi7JTSAaox9Cn/70p6PZOjUBWmSg6580k5J7ClhFUs2FZq4LjVu8eHH0gybXR6fi45sJWrsBzDXUKVfNNulYkPsd1ayaGhY1Lrmz+sX845tAdLNP/HiWeD82dRNIoffUjKj2vX5gaKYyvk4wnoGPx6gx03FIN58lWQrlU+MKfZ
/V6OlUqJrPeNHlIbouL24AdVpbx75ip4CVN11+oVk85SBeSm0AdYzU5Qw6BRxfGnPjjTdG17c2NQMYH2d16ryUU8D6Luhz5jdO+cd+/WDQmRs1mqUs8U0gmhnVGQgt2kY5eW8C0cyf9o3eV/sm/4dnKdtZrq/11O9y/UylbhcNYKliOa8vJUAqwn/961+jg7Gu0dD/r2craYZJjZpO0+mOTx2I9YtVRVCzW03dBBJvhm6A0KkV3Z2lpis+5aRfnHoWmBoJXTithke/6FSQ9GtWMwOt0QDq82idOtDrgKpCrJkc/WrUzRE6sOrUpQqPmqP4ztGmGsBi1wDGp311s0DuDKNmFdU8JF0KbXdrNIDF3DX7ooZBp//V2OhArdOzKvZqaOMGR82zTl3pFJn+UVZ084GKqq5T0yyU8qSL/zUTGt8Eov+mU266xksNmRY1DlqSXAOo2SLNWGg/aFFTpVkTWWsmUItmJLXNan5UmJVdnSbTNYfNPasud99oW3STgnKuU7TKkGYz1ejptLuKcnzTlBpjNRE6DVxonHz0ndL3SN8hXSup4qsfRq09A6gGREVR+1Gn27Uenc7U9zn/R5pc1GyoCdP3VEuxawD1Gl27pRlG7VOdLtWpWy3KRrzkf09kplk/manZ0/dD26kZSi26TlB3jA4YMCCacVPTrWOPsqjrFLUPkixN5bPQ91n7TbnW7LEuC9GxR7OBaoDiBlA/5nRdqq7L000gOkWsz67r57Tkvq+aeu1j3eARPxi61AYwvglEPyC175QVvZeyqx968RMS8i00Q6nvm75beo1O7Wq/H3744Y03geg1+s6pYddxL74JpNixX2dy9F1S8xafDdKp5mLXAGob5aHjvhrb+DEwOrbHj4Ep9J3XmSf9o8ZRZ470Hdc1zsqOjqEaI1f9MNA1hrnNn842dJSllPrdUT5z/uegAXTs2VICpF9rKqS6WUHj9OXSgT0u0ppF0YFNdxfqwKKDdVOPgcl9jIWuf9PdnjpY6Muau6j5iy9QVsOgL6+aQl1UrOLbGg2grhHTjJUOeiqI8WNgVHy1bjUf+ne6I1Czk/FNLk01gMXuAtbpkvzPqc+sWZb4dHd8elHXsTU1+1lou1ujAdS2NOeuBknNk5o5HbjVMOsaIt1sEl+UrgZABV0zW/FjYHRwVwOmmwd0gNY4NdkqqLp+K952zdSpqOs1mgGLGyltV5K7gDWrp2IkOzWR2m8qEtrHuTMZauTjG2SURxXG+C5oravY3a56jfKuHw5yUH71GfR8Mp3G1QyYHHMfAxNnu9A4NcxqUvVDSHd7antUxFq7AdQ+UHGVpZpWGes6OjXzhWbp9SNO3zPdQalTmUlc4gdByyD3QdC5xTf/e6LPq/0hSzV58lTzmfvgXs0Kxjcq6UejTv1qNi33sTDF7jyNG9T8fDb1fda+VCb1OfRDRZmVR+5f27nnnnuiLOhHohpeHaPixjX/fdWsqFnVMUzHz1IbwLiR1ndMTZ++Ozp7oRln/f/KTqFF+0THLp3C1vdSzVL8+B+9vrnHwBQ79uvYrLMmOlZqv8R36SfZF3qigWpI7oOglYX4QdCFvvPyz3/0jD5D/PSF+GxCIYf4+O4om2UztJT6XTYb3cobQgPoACVADrw2HKoDmX5Vq6AUumi7DVcd7K2TPqA3rQ1UQ65TSCo2LO1HQD9M1RyU8him9vPpCm+pboTT7J1Ov5d60057/+xZ3n7qd4b/FnBrBJ8AtYZi67+H7qrVqRn9b1aWcmoAVUh1ilQzKs39Saqs7Jv28jl1uYmuR42fndhetrvU7dQsuq6d1rWHultXM2iabWvqOsRS35/Xtw8B6jcNoCupHS1Auh5Gp56bWjSjluSvGbhQGdwigXJqAFv0ARiEQEoCuiZalx7oOjidLtdNQnosk64lZcmOQEer3y3Zc5wCbonaf8d0tADpAuz4Ts9CLLoeJ/7TRg42hiKAAAIIIBBUoKPV75Zg0gC2RK2DNoAOCoYigAACCCDQbgRoADkF7AorAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/c54Azhv3jw766yzbPLkyfbRRx/ZsGHD7KabbrKRI0cmCiQBSsTEixBAAAEEECgrAep3hhvAJUuW2E477WRjxoyxk046yTbccEN7/fXXbfPNN4/+SbIQoCRKvAYBBBBAAIHyEqB+Z7gBHD9+vD399NP25JNPtjiVBKjFdAxEAAEEEEAgmAD1O8MN4Lbbbmv77befvfPOOzZlyhQbNGiQnXzyyXbiiSc2Gci6ujrTP/GiANXU1Fhtba1VV1cHCzIrRgABBBBAAIHkAjSAGW4Au3XrFiXljDPOsMMOO8ymTZtm48aNs2uvvdbGjh1bMEUTJkywiRMnfuK/0QAm/9LxSgQQQAABBEIL0ABmuAHs0qVLdLPHM88805jD0047LWoEn3322YLZZAYw9FeW9SOAAAIIIOAXoAHMcAM4ZMgQ22effWzSpEmNSbrmmmvsvPPOM90dnGQhQEmUeA0CCCCAAALlJUD9znADeNRRR9ncuXPXuQnk9NNPt6lTp64zK9hcZAlQeX2h2RoEEEAAAQSSCFC/M9wA6lTv7rvvHl3Td/jhh9tzzz0X3QBy/fXX29FHH50kP0aAE
jHxIgQQQAABBMpKgPqd4QZQSXzggQfs7LPPjp7/N3To0OiGkObuAs5PLwEqq+8zG4MAAggggEAiAep3xhvARClp5kUEyCvIeAQQQAABBNIXoH7TALpSR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQTWisUxAAAgAElEQVSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gI3Bu+iii+zss8+2cePG2WWXXZYokAQoERMvQgABBBBAoKwEqN80gFEgp02bZocffrhVV1fbmDFjaADL6mvKxiCAAAIIINC6AjSANIC2YsUK23nnne3qq6+28847z4YPH04D2LrfM94NAQQQQACBshKgAaQBtLFjx9oGG2xgl156qe299940gGX1FWVjEEAAAQQQaH0BGsCMN4C33367nX/++dEp4G7duhVtAOvq6kz/xIsCVFNTY7W1tdHpYxYEEEAAAQQQKH8BGsAMN4Bz5861kSNH2qOPPmo77LBDlNZiM4ATJkywiRMnfiLZNIDl/2VnCxFAAAEEEMidwOndu3emJ3AqGhoaGrIYifvuu88OOeQQq6qqavz4a9eutYqKCqusrIxm+nL/m17EDGAWk8JnRgABBBDoaALMAGZ4BnD58uX21ltvrZPp448/3rbeems766yzbPvtty+adwJUlIgXIIAAAgggUHYC1O8MN4CF0ljsFHD+GAJUdt9pNggBBBBAAIGiAtRvGsB1QkIDWPQ7wwsQQAABBBBo9wI0gDSArhATIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIpCjwYe2H9t7cxdEa+9f0tZ69e6a4dlaFQHkJUL9pAF2JJEAuPgYjgEAKAnUr6+yZPz5vM5542WrfXx6tsXe/XvapPbe13Q8aaXVVq2xVwyrr06m3darqlMIWsQoEwgtQv2kAXSkkQC4+BiOAQBsLrF2z1h647hGb/teZVt23l/Xp3ztaY+2iWnt7wFxb+5lVVtf7Y2sws16d17Od1x9uhw462Lp16trGW8bbIxBWgPpNA+hKIAFy8TEYAQTaWGD29H/bHT+/z/oO3MB69OreuLbZw2bbW0PmWH1FvXXu3NmqKqpsra21CjMb2nNT+/5W36UJbON9w9uHFaB+0wC6EkiAXHwMRgCBNhZ48IZH7YVHXrJNt6tpXNOynsvshdHP29pOa61hTYNVdq60yspKq6qoNLMKa7B6O2DAvvaVmkPaeOt4ewTCCVC/aQBd6SNALj4GI4BAGwvcftG9Nudfc23g5hs3rulf2820+YPmq9ez6NxvhVnU+1mFVVqlNViDbdR1I7t4x59aRYVexIJAxxOgftMAulJNgFx8DEYAgTYWeOjGv9rzD0+3TbfbpHFNU3f9uy3rvew/DWC85PV53Su729UjLrVOldwU0sa7iLcPJED9pgF0RY8AufgYjAACbSzwxotzTLOAGwxY33pW94jWNuUzj9uqHqv+vwEsMMnXuaKzTdrlKqv8z9QgCwIdToD6TQPoCjUBcvExGAEE2lhAdwE/NOmv9sKjL1mvPj2tT/9qe/zAx80669zvf5cCDWCPiu523agr2njreHsEwglQv2kAXekjQC4+BiOAQAoCq+pW29QHX7CXprxsS95baq+e9nKzs3/apG4V3eyGUVemsHWsAoEwAtRvGkBX8giQi4/BCCCQosDKFSvtzTfetos+vqToWqusym4efV3R1/ECBNqrAPWbBtCVXQLk4mMwAgikLPDq0tftvFcvLrpWnRX+zehJRV/HCxBorwLUbxpAV3YJkIuPwQggkLLA79640x5+/5FEa/0tDWAiJ17UPgWo3zSAruQSIBcfgxFAIGWB46d+29bYmkRrpQFMxMSL2qkA9ZsG0BVdAuTiYzACCKQscMzUbyReIw1gYipe2A4FqN80gK7YEiAXH4MRQCBlARrAlMFZXdkKUL9pAF3hJEAuPgYjgEDKAjSAKYOzurIVoH7T
ALrCSYBcfAxGAIGUBeIGsIcttl8N+JdVVZk1NJhd+W6lTa/fo3Frdusx2k7+1Ikpbx2rQyA9Aeo3DaArbQTIxcdgBBBIWUAN4A0DnrBOncwqcv4CiJpA/fONd0ZYvfW087b+sQ3p/f9/PzjlzWR1CLS5APWbBtAVMgLk4mMwAgikKLB27VpbPW8b69y56ZXW15ud8M6e9uuR11jnqmZemOJ2syoE2kKA+k0D6MoVAXLxMRgBBFIU+NOsB+2A3qdHa8yd/cvdBM0C/uNds1EjXktxy1gVAukLUL9pAF2pI0AuPgYjgECKAktf29J69Sq+wrVrzboMpgEsLsUr2rMA9ZsG0JVfAuTiYzACCKQosPKtLa1Ll+IrpAEsbsQr2r8A9ZsG0JViAuTiYzACCKQo8PHbWzZ7/V+8KWvWmHWtYQYwxV3DqgIIUL9pAF2xI0AuPgYjgECKAq/N2NI271d8hSs/NltvKA1gcSle0Z4FqN80gK78EiAXH4MRQCBFAT0C5qaaJ5q8AUSboptAbl56jn1j2+NS3DJWhUD6AtRvGkBX6giQi4/BCCCQksCaNWvs+Be+bV/t+oTt37/wXcBq/hYvM1s7eIoNXG9ASlvGahAII0D9pgF0JY8AufgYjAACKQk8NOcR+/3CO6O1ja56wr414D9NoP6JHwI9eZHZXXV72i2jrrfKisqUtozVIBBGgPpNA+hKHgFy8TEYAQRSEjjjufG2qOH9ddZWYfXW2epttVVag/1/w/fb0ZNS2ipWg0A4Aeo3DaArfQTIxcdgBBBISeDbU0+zD+2jRGujAUzExIvauQD1mwbQFWEC5OJjMAIIpCTwnWmn27L65UXXVmVVdvPo64q+jhcg0N4FqN80gK4MEyAXH4MRQCAlgZOnnW7LEzSAxww6yvYd/NmUtorVIBBOgPpNA+hKHwFy8TEYAQRSEvj2tNPsw/rip4An7XSVde3SNaWtYjUIhBOgftMAutJHgFx8DEYAgZQEjpv6LVtra4uujev/ihLxgg4iQP2mAXRFmQC5+BiMAAIpCKxdu9ZOeP4kq7f6ZtfWpaKL3Tjq6hS2iFUgEF6A+k0D6EohAXLxMRgBBFIQWLqq1k6ffpatsTXNrm3Dzv3slztflMIWsQoEwgtQv2kAXSkkQC4+BiOAQAoCS1cttZ/MvMA+WP1Bs2vbu99n7Oubj01hi1gFAuEFqN8ZbwAvvPBCu+eee+yVV16x7t272+67724XX3yxbbXVVonSSYASMfEiBBAIKLC6frVdM3uSTV/yYpOzgF0ru9oPt/6BDe21acAtZdUIpCdA/c54A7j//vvbEUccYbvssovpb2Wec845NnPmTHv55ZetZ8+eRZNIgIoS8QIEECgDgX988E/7wzv326K6921V/arG6wErrdK6VHaxT/cbbcdsehR/Aq4M9hWbkI4A9TvjDWB+zBYtWmT9+/e3KVOm2J577lk0hQSoKBEvQACBMhBQ0/fXBY/bi0tn2NLVtVZXXxdtVc+qnjas12b2xYEHWL+ufctgS9kEBNIRoH7TAK6TtNmzZ9sWW2xhM2bMsO233/4TKayrqzP9Ey8KUE1NjdXW1lp1dXU6qWUtCCCAQAsEdCp49oo3bNay12zZ6mXWpbKzbdlri+if6s69WvCODEGg/QrQANIANqa3vr7eDjroIFu6dKk99dRTBVM9YcIEmzhx4if+Gw1g+z0IsOUIZE2goaEhOgWs078VFRVZ+/h8XgQiARpAGsDGr8JJJ51kkydPjpq/wYMHF/yKMAPIkQMBBBBAAIH2L0ADSAMYpfiUU06x+++/35544gkbOnRo4mQToMRUvBABBBBAAIGyEaB+Z7wB1KmQU0891e699157/PHHo+v/SlkIUClavBYBBBBAAIHyEKB+Z7wBPPnkk+22226LZv9yn/3Xu3fv6LmAxRYCVEyI/44AAggggED5CVC/M94ANnUB9E033WTHHXdc0cQSoKJEvAABBBBAAIGyE6B+Z7wB9CaSAHkFGY8AAggggED6AtRvGkBX6giQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAm0qsPijj+zFhQvsjSWLraHBbPMNNrAdNxpgvbt2tVnvL7KZ7y202o8/tj7dutt6XbvYW0uW2AOvzrI3ly6xNW26Zeu+ebfKShvWt68dOGwrO2irbWxgdXWKa2dVCGRTgPpNA2hXXXWVXXLJJbZgwQLbcccd7YorrrBRo0Yl+kYQoERMvAiB1AVeeX+R3T7zJVuwYoV17dQpWn/dmjXWt0d3W69LF5u
/fIXVN9RbVUWFvbFkib334QpbvHJl6tuZu8IulZW23Yb97cw99rTRg2qCbgsrR6CjC1C/M94A3nHHHXbsscfatddea6NHj7bLLrvM7rrrLnv11Vetf//+RfNPgIoS8QIEUhf4YOVHdsVzfzfNAG7aZ32rrKiItqG+ocGeentO1Oh9ZpNNrV+PHvba4vft1fcX2dxly2xNQ33q25q/wu6dOtnIAQPtgs/tZ4OYCQy+P9iAjitA/c54A6imb5dddrErr7zyPwWivt5qamrs1FNPtfHjxxdNPgEqSsQLEEhd4G9z3rQ7Zs6wLfv2a2z+tBErV6+2p+e+bbV1K23nAQNtYK9qe2bu26aGcd7y5alvZ6EVdq6osA17rmen77q7Hbrt9mWxTWwEAh1RgPqd4QZw1apV1qNHD7v77rvt4IMPbsz32LFjbenSpXb//fd/IvN1dXWmf+JFAVLDWFtba9X8Wu+Ixwg+UzsUmPSP523GwoU2dP3119n69z780KbNf8cqzaIma9M+fey5ee/YBytX2gcfhz39G2+oTkmv37Wb7TdsS/vZZz/fDvXZZATahwANYIYbwPnz59ugQYPsmWeesd12260xsWeeeaZNmTLFpk6d+okUT5gwwSZOnPiJf08D2D6+8GxlNgRoALOxn/mUCHgEaABpAEtqAJkB9HzdGItAOgKcAk7HmbUg0J4FaAAz3AC25BRwftgJUHv++rPtHVUgvgnk/Y8+tKF9NuAmkI66o/lcCDgEqN8ZbgCVG90Eoke+6NEvWnQTyCabbGKnnHIKN4E4vlgMRSC0wDqPganqZFZh9vGaNdavHTwG5qw99rRRPAYmdIRYfwcXoAHMeAOox8Dopo/rrrsuagT1GJg777zTXnnlFdtoo42Kxp8AFSXiBQgEE9BjYF56b4HN/uCTD4J+ZfH7NmPhgv88CLp7d+vVpYvNWbLEHnxtlr25ZImtTnGru1dW2uZ9+9oXttjKvrQlD4JOkZ5VZViA+p3xBlDZ1yNg4gdBDx8+3C6//PJoZjDJQoCSKPEaBBBAAAEEykuA+k0D6EokAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPU7ow3gnDlz7Gc/+5k99thjtmDBAhs4cKB97Wtfsx/+8IfWpUuXxGEkQImpeCECCCCAAAJlI0D9zmgD+PDDD9sdd9xhRx55pA0bNsxmzpxpJ554oh1zzDH2i1/8InFACVBiKl6IAAIIIIBA2QhQvzPaABZK4CWXXGLXXHONvfnmm4kDSoASU/FCBBBAAAEEykaA+k0D2BjGc8891zQz+PzzzycOKAFKTMULEUAAAQQQKBsB6jcNYBTG2bNn24gRI6LTvzoV3NRSV1dn+ideFKCamhqrra216urqsgk2G4IAAggggAACTQvQAHawBnD8+PF28cUXN5v5WbNm2dZbb934mnnz5tlee+1le++9t02aNKnZsRMmTLCJEyd+4jU0gBxmEEAAAQQQaD8CNIAdrAFctGiRLV68uNkEbrbZZo13+s6fPz9q/HbddVe7+eabrbKystmxzAC2ny83W4oAAggggEBTAjSAHawBLCXqmvkbM2ZMdOr31ltvtaqqqlKGR68lQCWTMQABBBBAAIHgAtTvjDaAav408zdkyBC75ZZb1mn+Nt5448TBJECJqXghAggggAACZSNA/c5oA6jTvccff3zBIDY0NCQOKAFKTMULEUAAAQQQKBsB6ndGG8DWSiABai1J3gcBBBBAAIH0BKjfNICutBEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh
4BcvExGAEEEEAAgSAC1G8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAXcEjQC4+BiOAAAIIIBBEgPpNA+gKHgFy8TEYAQQQQACBIALUbxpAV/AIkIuPwQgggAACCAQRoH7TALqCR4BcfAxGAAEEEEAgiAD1mwbQFTwC5OJjMAIIIIAAAkEEqN80gK7gESAXH4MRQAABBBAIIkD9pgF0BY8AufgYjAACCCCAQBAB6jcNoCt4BMjFx2AEEEAAAQSCCFC/aQBdwSNALj4GI4AAAgggEESA+k0D6AoeAXLxMRgBBBBAAIEgAtRvGkBX8AiQi4/BCCCAAAIIBBGgftMAuoJHgFx8DEYAAQQQQCCIAPWbBtAVPALk4mMwAggggAACQQSo3zSAruARIBcfgxFAAAEEEAgiQP2mAXQFjwC5+BiMAAIIIIBAEAHqNw2gK3gEyMXHYAQQQAABBIIIUL9pAF3BI0AuPgYjgAACCCAQRID6TQPoCh4BcvExGAEEEEAAgSAC1Ipcb4cAAAznSURBVG8aQFfwCJCLj8EIIIAAAggEEaB+0wC6gkeAXHwMRgABBBBAIIgA9ZsG0BU8AuTiYzACCCCAAAJBBKjfNICu4BEgFx+DEUAAAQQQCCJA/aYBdAWPALn4GIwAAggggEAQAeo3DaAreATIxcdgBBBAAAEEgghQv2kAra6uzkaPHm0vvviiTZ8+3YYPH544jAQoMRUvRAABBBBAoGwEqN80gDZu3Dh7/fXXbfLkyTSAZfPVZEMQQAABBBBoOwEawIw3gGr6zjjjDPvDH/5g2223HQ1g233XeGcEEEAAAQTKRoAGMMMN4MKFC23EiBF23333Wb9+/Wzo0KFFG0CdLtY/8VJbW2ubbLKJzZ0716qrq8sm2GwIAggggAACCDQtoAawpqbGli5dar17984kVUVDQ0ND1j65PvKBBx5on/70p+3cc8+1OXPmJGoAJ0yYYBMnTswaF58XAQQQQACBDinwxhtv2GabbdYhP1uxD9WhGsDx48fbxRdf3OxnnjVrlj3yyCN255132pQpU6yqqipxA5g/A6hfDkOGDLG33347s78gigUsrf8e/5pjNjYt8eK/rNkX4feFtoDvRnnsB/ZF+ewHbUl8Bm/JkiXWp0+f8tq4lLamQzWAixYtssWLFzdLp07/8MMPtz/96U9WUVHR+Nq1a9dGzeDRRx9tt9xySyJ+riFIxJTKi9gXqTAnWgn7IhFTai9if6RGXXRF7IuiRKm9gH2R0WsANWOnnR8v8+fPt/3228/uvvvu6JEwgwcPThRCApSIKZUXsS9SYU60EvZFIqbUXsT+SI266IrYF0WJUnsB+yKjDWB+wpJeA5g/jgCl9l0tuiL2RVGi1F7AvkiNOtGK2B+JmFJ5EfsiFeZEK2Ff0ABGQWlpA6hrAi+88EI7++yzrWvXrolCx4vaRoB90TauLXlX9kVL1NpuDPuj7WxLfWf2Ralibfd69gUNYNuli3dGAAEEEEAAAQTKVKBD3QRSpsZsFgIIIIAAAgggUFYCNIBltTvYGAQQQAABBBBAoO0FaADb3pg1IIAAAggggAACZSVAA1hWu4ONQQABBBBAAAEE2l6ABrCFxldddZVdcskltmDBAttxxx3tiiuusFGjRrXw3RiWREB3XN9zzz32yiuvWPfu3W333XeP/vLLVltt1Tj8448/tu9973t2++23R3+3Wc93vPrqq22jjTZKsgpe00KBiy66KLobfty4cXbZZZdF78K+aCFmC4fNmzfPzjrrLJs8ebJ99NFHNmzYMLvpppts5MiR0TvqT2D+5Cc/sRtuuCH6+6f6U5jXXHONbbHFFi1cI8MKCeiPCujPht56661RfRg4cKAdd9xx0Z8djf/4APuibbLzxBNPRHX5hRdesHfffdfuvfdeO/jggxtXlsT9gw8+sFNPPTX6YxGVlZV26KGH2q9+9Stbb7312majA74rDWAL8O+44w479thj7dprr40eHK2Cd9ddd9mrr75q/fv3b8E7MiSJwP77729HHHGE7bLLLrZmzRo755xzbObMmfbyyy9bz549o7c46aST7MEHH7Sbb745+vN8p5xySvQlfvrpp5Osgte0QGDatGnRX9eprq62MWPGNDaA7IsWYLZwiP6c1U477RT5y33DDTe0119/3TbffPPoHy36saQfUfpLR0OHDrUf/ehHNmPGjOj7061btxaumWH5AhdccIH98pe/jJy32247e/755+3444+3888/30477TT2RRtGRj9+dKwfMWKEffnLX/5EA5jkO3DAAQdEzeN1111nq1evjvadas5tt93Whlse5q1pAFvgrqZPgbjyyiuj0fX19VZTUxP9atDfI2ZJR0B/+k8Nt/6m85577hn9bUcVPn1Rv/KVr0QbodnCbbbZxp599lnbdddd09mwDK1lxYoVtvPOO0ezrOedd54NHz48agDZF+mGQMcdFb4nn3yy4Io186GZKM2Of//7349eo32kmXH9WNIPK5bWEfjiF78Yud54442Nb6hZJJ210Kwg+6J1nIu9i2Zbc2cAk7jPmjXLtt12W9OP2njm/OGHH7YDDzzQ3nnnneg71JEWGsAS9+aqVausR48e0Z+Ny51aHjt2bHRa5f777y/xHXl5SwVmz54dnb7SLMb2229vjz32mH3uc5+z/D/uPWTIEPvud79rp59+ektXxbgmBJT7DTbYwC699FLbe++9GxtA9kW6kVHR0uUOKlL6QTRo0CA7+eST7cQTT4w25M0334xmAqdPnx7to3jZa6+9ov9fp7hYWkdAM4DXX3+9PfLII7blllvaiy++aPvuu280K6i/Nc++aB3nYu+S3wAmcf/1r38d/UhSDYkXnW3SDLnO8h1yyCHFVtuu/jsNYIm7S383WAfXZ555xnbbbbfG0WeeeWZ04J06dWqJ78jLWyKgWdeDDjooarqfeuqp6C0086fpel37l7vo2kydGtP0P0vrCeg6S53W0q9lHSBzG0D2Res5J3mn+BTuGWecYYcddli0T3Q9pi5TUZOu45Wu+dPxa8CAAY1vqVP3KpS6rIWldQR0bNLlKT//+c+tqqrKdE2gvie6RlYL+6J1nIu9S34DmMRdzbtO3etyrtxFZ5omTpwYXV7RkRYawBL3Jg1giWBt9HJ9EXW9h5q/wYMH0wC2kXNTbzt37tzoFMmjjz5qO+ywQ/QyGsCUd0LO6rp06RLtDxW5eNH1ZmoEdflDkuIXbus71pr1w+gHP/hBdDOCrgH85z//GZ2B0AwgzXh6+5oGsLg1DWBxo3VewSngEsHa4OW6sUOn2nXHly5mjxdOO7YBdhNved9990WnQzTDES+a6dBBVzfd/PnPf7bPf/7znI5PaZfoMod99tnHJk2a1LhG3eGr6zJ1d3CS018pbWqHX42uB9c1md/5zncaP
6v2g67/0zXJ7It0IsAp4OLONIDFjT7xCt0EotOKevSLFk35b7LJJtEdp9wE0gLQhEN0Ea9utNGFvY8//vgnHl8R33jw+9//Prp1X4um8rfeemtuAklonPRly5cvt7feemudl+v0u6z1KBIVQd2Qw75IKup73VFHHWWalc29CUTXvOqSFM3+xRfA6wYQXeOkZdmyZdFNVNwE4rPPH923b9+o8c49Xai7r/VIntdee4190brcTb5bUzeBNPcdiG8C0Z3bupNYi67l1BMouAkkpR1X7qvR9TKaytdt4moEddfjnXfeGf2643lzbbf3dFG7ri3T7F/us//0uBfdYadFB92HHnooKmp6LIkaRi25p8babguz/c65p4DZF+lmQad69VxMXaek6/qee+656AYQ3YygGw+06BpYPa8x9zEwL730Eo+BaeVdpWf+/eUvf4nqg04B68abb37zm3bCCSc0XofMvmhl9P++nZ5KoJsDteixSDrtruu/daOaJmmSuOsxMAsXLoyun40fA6PLK3gMTNvss3b5rnoETPwgaN1Fd/nll0fPBGRpO4H4Iar5a9Avax10tcQPH9bMU+6DoDfeeOO22zDeORLIbwDZF+kG44EHHohuNNDz/3RphG4Iie8C1pbED8FVU6ibp/bYY4/o8T26U5Wl9QQ0O65nLOpMxXvvvRc9OuTII4+0H//4x6ZrNdkXrWed/046M6SGL3/RhI0mBZJ8B/QgaJ3Ny30QtOo7D4Juu/3GOyOAAAIIIIAAAgikJMA1gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBGgAy2VPsB0IIIAAAggggEBKAjSAKUGzGgQQQAABBBBAoFwEaADLZU+wHQgggAACCCCAQEoCNIApQbMaBBBAAAEEEECgXARoAMtlT7AdCCCAAAIIIIBASgI0gClBsxoEEEAAAQQQQKBcBP4P8ta9QDIyYDwAAAAASUVORK5CYII=\" width=\"640\">"
247
+ ],
248
+ "text/plain": [
249
+ "<IPython.core.display.HTML object>"
250
+ ]
251
+ },
252
+ "metadata": {},
253
+ "output_type": "display_data"
254
+ }
255
+ ],
256
+ "source": [
257
+ "anim = animate_episode(predictor_data[episode_id][\"solver_info\"], predictor_data[episode_id][\"ado_position_future\"], predictor_data[episode_id][\"ado_unbiased_predictions\"])\n",
258
+ "anim.save(os.path.join(stats_dir, f\"ep_{episode_id}_predictor.mp4\"), writer=\"ffmpeg\")"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": null,
264
+ "id": "39c17b02",
265
+ "metadata": {},
266
+ "outputs": [],
267
+ "source": []
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "id": "144aec04",
273
+ "metadata": {},
274
+ "outputs": [],
275
+ "source": []
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "b5dbaa75",
281
+ "metadata": {},
282
+ "outputs": [],
283
+ "source": []
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": null,
288
+ "id": "a86d7330",
289
+ "metadata": {},
290
+ "outputs": [],
291
+ "source": []
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "id": "da0e8629",
297
+ "metadata": {},
298
+ "outputs": [],
299
+ "source": []
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "id": "342b6e50",
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": []
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "id": "11cf1be2",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": []
316
+ }
317
+ ],
318
+ "metadata": {
319
+ "kernelspec": {
320
+ "display_name": "Python 3 (ipykernel)",
321
+ "language": "python",
322
+ "name": "python3"
323
+ },
324
+ "language_info": {
325
+ "codemirror_mode": {
326
+ "name": "ipython",
327
+ "version": 3
328
+ },
329
+ "file_extension": ".py",
330
+ "mimetype": "text/x-python",
331
+ "name": "python",
332
+ "nbconvert_exporter": "python",
333
+ "pygments_lexer": "ipython3",
334
+ "version": "3.7.13"
335
+ },
336
+ "vscode": {
337
+ "interpreter": {
338
+ "hash": "a1479098f155d52cde87e35cb1613d4d825087d81bb03677a9d084ad747a84cc"
339
+ }
340
+ }
341
+ },
342
+ "nbformat": 4,
343
+ "nbformat_minor": 5
344
+ }
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ torch>=1.12
2
+ matplotlib
3
+ numpy
4
+ mmcv>=1.4.7
5
+ pytorch-lightning
6
+ pytest
7
+ setuptools>=59.5.0
8
+ wandb
9
+ plotly
10
+ scipy
11
+ gradio>=3.7
12
+ datasets
13
+ huggingface_hub
14
+ einops
15
+ pydantic
16
+ fire
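As a quick setup check (not part of this commit), the dependencies above can be installed with pip and verified with a minimal Python snippet; the version numbers printed depend on whatever pip resolves under these constraints.

# Hypothetical sanity check, assuming `pip install -r requirements.txt` has completed.
import torch
import pytorch_lightning as pl
import einops  # noqa: F401
import mmcv  # noqa: F401

print("torch:", torch.__version__)
print("pytorch-lightning:", pl.__version__)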
risk_biased/__init__.py ADDED
File without changes
risk_biased/config/learning_config.py ADDED
@@ -0,0 +1,156 @@
1
+ from risk_biased.config.paths import (
2
+ log_path,
3
+ )
4
+
5
+ # WandB Project Name
6
+ project = "RiskBiased"
7
+ entity = "tri"
8
+
9
+ # Scene Parameters
10
+ dt = 0.1
11
+ time_scene = 5.0
12
+ sample_times = [t * dt for t in range(0, int(time_scene / dt))]
13
+ ego_ref_speed = 14.0
14
+ ego_length = 4.0
15
+ ego_width = 1.75
16
+ fast_speed = 2.0
17
+ slow_speed = 1.0
18
+ p_change_pace = 0.2
19
+ proportion_fast = 0.5
20
+
21
+ # Data Parameters
22
+ file_name = "scene_data"
23
+ datasets_sizes = {"train": 100000, "val": 10000, "test": 30000}
24
+ datasets = list(datasets_sizes.keys())
25
+ state_dim = 2
26
+ dynamic_state_dim = 2
27
+ num_steps = 5
28
+ num_steps_future = len(sample_times) - num_steps
29
+ ego_speed_init_low = 4.0
30
+ ego_speed_init_high = 16.0
31
+ ego_acceleration_mean_low = -1.5
32
+ ego_acceleration_mean_high = 1.5
33
+ ego_acceleration_std = 3.0
34
+ perception_noise_std = 0.05
35
+ map_state_dim = 0
36
+ max_size_lane = 0
37
+ num_blocks = 3
38
+ interaction_type = None
39
+ mcg_dim_expansion = 0
40
+ mcg_num_layers = 0
41
+ num_attention_heads = 4
42
+
43
+
44
+ # Model Hyperparameters
45
+ model_type = "encoder_biased"
46
+ condition_on_ego_future = True
47
+ latent_dim = 2
48
+ hidden_dim = 64
49
+ num_vq = 256
50
+ latent_distribution = "gaussian" # "gaussian" or "quantized"
51
+ num_hidden_layers = 3
52
+ sequence_encoder_type = "MLP" # one of "MLP", "LSTM", "maskedLSTM"
53
+ sequence_decoder_type = "MLP" # one of "MLP", "LSTM", "maskedLSTM"
54
+ is_mlp_residual = True
55
+
56
+ # Variational Loss Hyperparameters
57
+ kl_weight = 0.5 # For the didactic example, kl_weight = 0.5 works well with the gaussian latent and kl_weight = 0.1 with the quantized latent
58
+ kl_threshold = 0.1
59
+ latent_regularization = 0.1
60
+
61
+ # Risk distribution should be one of the following types:
62
+ # {"type": "uniform", "min": 0, "max": 1},
63
+ # {"type": "normal", "mean": 0, "sigma": 1},
64
+ # {"type": "bernoulli", "p": 0.5, "min": 0, "max": 1},
65
+ # {"type": "beta", "alpha": 2, "beta": 5, "min": 0, "max": 1},
66
+ # {"type": "chi2", "k": 3, "min": 0, "scale": 1},
67
+ # {"type": "log-normal", "mu": 0, "sigma": 1, "min": 0, "scale": 1}
68
+ # {"type": "log-uniform", "min": 0, "max": 1, "scale": 1}
69
+ risk_distribution = {"type": "log-uniform", "min": 0, "max": 1, "scale": 3}
70
+
71
+
72
+ # Monte Carlo risk estimator should be one of the following types:
73
+ # {"type": "entropic", "eps": 1e-4}
74
+ # {"type": "cvar", "eps": 1e-4}
75
+
76
+ risk_estimator = {"type": "cvar", "eps": 1e-3}
77
+ if latent_distribution == "quantized":
78
+ # Number of samples used to estimate the risk from the unbiased distribution
79
+ n_mc_samples_risk = num_vq
80
+ # Number of samples used to estimate the averaged cost of the biased distribution
81
+ n_mc_samples_biased = num_vq
82
+ else:
83
+ # Number of samples used to estimate the risk from the unbiased distribution
84
+ n_mc_samples_risk = 512
85
+ # Number of samples used to estimate the averaged cost of the biased distribution
86
+ n_mc_samples_biased = 256
87
+
88
+
89
+ # Risk Loss Hyperparameters
90
+ risk_weight = 1
91
+ risk_assymetry_factor = 200
92
+ use_risk_constraint = True # For encoder_biased only
93
+ risk_constraint_update_every_n_epoch = (
94
+ 1 # For encoder_biased only, not used if use_risk_constraint == False
95
+ )
96
+ risk_constraint_weight_update_factor = (
97
+ 1.5 # For encoder_biased only, not used if use_risk_constraint == False
98
+ )
99
+ risk_constraint_weight_maximum = (
100
+ 1e5 # For encoder_biased only, not used if use_risk_constraint == False
101
+ )
102
+
103
+
104
+ # Training Hyperparameters
105
+ learning_rate = 1e-4
106
+ batch_size = 512
107
+ num_epochs_cvae = 100
108
+ num_epochs_bias = 100
109
+ gpus = [0]
110
+ seed = 0 # Setting seed to an integer value sets the seed for the pseudo-random number generators in: pytorch, numpy, python.random
111
+ early_stopping = False
112
+ accumulate_grad_batches = 1
113
+
114
+ num_workers = 4
115
+ log_weights_and_grads = False
116
+ num_samples_min_fde = 16
117
+ val_check_interval_epoch = 1
118
+ plot_interval_epoch = 1
119
+ histogram_interval_epoch = 1
120
+
121
+ # State Cost Hyperparameters
122
+ cost_scale = 10
123
+ cost_reduce = (
124
+ "mean" # choose in "discounted_mean", "mean", "min", "max", "now", "final"
125
+ )
126
+ discount_factor = 0.95 # only used if cost_reduce == "discounted_mean", discounts the cost by this factor at each time step
127
+ distance_bandwidth = 2
128
+ time_bandwidth = 0.5
129
+ min_velocity_diff = 0.03
130
+
131
+
132
+ # List all of the above parameters that affect the dataset, to distinguish datasets once generated
133
+ dataset_parameters = {
134
+ "dt": dt,
135
+ "time_scene": time_scene,
136
+ "sample_times": sample_times,
137
+ "ego_ref_speed": ego_ref_speed,
138
+ "ego_speed_init_low": ego_speed_init_low,
139
+ "ego_speed_init_high": ego_speed_init_high,
140
+ "ego_acceleration_mean_low": ego_acceleration_mean_low,
141
+ "ego_acceleration_mean_high": ego_acceleration_mean_high,
142
+ "ego_acceleration_std": ego_acceleration_std,
143
+ "fast_speed": fast_speed,
144
+ "slow_speed": slow_speed,
145
+ "p_change_pace": p_change_pace,
146
+ "proportion_fast": proportion_fast,
147
+ "file_name": file_name,
148
+ "datasets_sizes": datasets_sizes,
149
+ "state_dim": state_dim,
150
+ "num_steps": num_steps,
151
+ "num_steps_future": num_steps_future,
152
+ "perception_noise_std": perception_noise_std,
153
+ }
154
+
155
+ # List files that should be saved as log
156
+ files_to_log = []
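The risk_distribution and risk_estimator entries above are plain dictionaries, so selecting another of the listed types only requires replacing them with dictionaries that follow the corresponding templates from the comments. A minimal sketch with purely illustrative values:

# Illustrative alternatives, following the templates enumerated in the comments above.
risk_distribution = {"type": "beta", "alpha": 2, "beta": 5, "min": 0, "max": 1}
risk_estimator = {"type": "entropic", "eps": 1e-4}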
risk_biased/config/paths.py ADDED
@@ -0,0 +1,9 @@
1
+ # Data split path:
2
+ base_path = "<set your path to the pre-processed data directory here>"
3
+ data_dir = "interactive_veh_type" # This directory name conditions the model hyperparameters; make sure to set it correctly
4
+ sample_dataset_path = base_path + data_dir + "/sample"
5
+ val_dataset_path = base_path + data_dir + "/validation"
6
+ train_dataset_path = base_path + data_dir + "/training"
7
+ test_dataset_path = base_path + data_dir + "/sample"
8
+
9
+ log_path = "<set your path to any directory where you want the logs to be stored>"
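For concreteness, paths.py might be filled in as in the sketch below; the directories are hypothetical, base_path must keep a trailing slash because the dataset paths are built by plain string concatenation, and data_dir should match one of the cases handled in waymo_config.py.

# Hypothetical example values for paths.py.
base_path = "/data/waymo_preprocessed/"  # trailing "/" required by the concatenations below
data_dir = "interactive_veh_type"
sample_dataset_path = base_path + data_dir + "/sample"
val_dataset_path = base_path + data_dir + "/validation"
train_dataset_path = base_path + data_dir + "/training"
test_dataset_path = base_path + data_dir + "/sample"
log_path = "/home/user/logs/risk_biased"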
risk_biased/config/planning_config.py ADDED
@@ -0,0 +1,13 @@
1
+ # Tracking Cost Parameters
2
+ tracking_cost_scale_longitudinal = 1e-2
3
+ tracking_cost_scale_lateral = 1e-0
4
+ tracking_cost_reduce = "mean"
5
+
6
+ # Cross Entropy Solver Parameters
7
+ num_control_samples = 100
8
+ num_elite = 30
9
+ iter_max = 10
10
+ smoothing_factor = 0.2
11
+ mean_warm_start = True
12
+ acceleration_std_x_m_s2 = 2.0
13
+ acceleration_std_y_m_s2 = 0.0
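These solver parameters follow the usual cross-entropy-method pattern: sample num_control_samples control sequences, keep the num_elite lowest-cost ones, and smooth the sampling distribution toward the elite statistics for up to iter_max iterations, optionally warm-starting the mean from the previous solution. The sketch below is generic textbook CEM written under those assumptions, not the actual implementation in risk_biased/mpc_planner/solver.py.

import torch

def cem_update(mean, std, cost_fn, num_control_samples=100, num_elite=30, smoothing_factor=0.2):
    # Sample candidate control sequences around the current mean and std.
    samples = mean + std * torch.randn(num_control_samples, *mean.shape)
    costs = cost_fn(samples)  # expected shape: (num_control_samples,)
    elite = samples[torch.argsort(costs)[:num_elite]]  # keep the lowest-cost candidates
    # Smooth the sampling distribution toward the elite statistics.
    new_mean = (1.0 - smoothing_factor) * elite.mean(dim=0) + smoothing_factor * mean
    new_std = (1.0 - smoothing_factor) * elite.std(dim=0) + smoothing_factor * std
    return new_mean, new_std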
risk_biased/config/waymo_config.py ADDED
@@ -0,0 +1,104 @@
1
+ from risk_biased.config.paths import (
2
+ data_dir,
3
+ sample_dataset_path,
4
+ val_dataset_path,
5
+ train_dataset_path,
6
+ test_dataset_path,
7
+ log_path,
8
+ )
9
+
10
+ # Data augmentation:
11
+ normalize_angle = True
12
+ random_rotation = False
13
+ angle_std = 3.14 / 4
14
+ random_translation = False
15
+ translation_distance_std = 0.1
16
+ p_exchange_two_first = 0.5
17
+
18
+ # Data diminution:
19
+ min_num_observation = 2
20
+ max_size_lane = 50
21
+ train_dataset_size_limit = None
22
+ val_dataset_size_limit = None
23
+ max_num_agents = 50
24
+ max_num_objects = 50
25
+
26
+ # Data characterization:
27
+ time_scene = 9.1
28
+ dt = 0.1
29
+ num_steps = 11
30
+ num_steps_future = 80
31
+
32
+ # TODO: avoid conditioning on the name of the directory in the path
33
+ if data_dir == "interactive_veh_type":
34
+ map_state_dim = 2 + num_steps * 8
35
+ state_dim = 11
36
+ dynamic_state_dim = 5
37
+ elif data_dir == "interactive_full":
38
+ map_state_dim = 2
39
+ state_dim = 5
40
+ dynamic_state_dim = 5
41
+ else:
42
+ map_state_dim = 2
43
+ state_dim = 2
44
+ dynamic_state_dim = 2
45
+
46
+ # Variational Loss Hyperparameters
47
+ kl_weight = 1.0
48
+ kl_threshold = 0.01
49
+
50
+ # Training Parameters
51
+ learning_rate = 3e-4
52
+ batch_size = 64
53
+ accumulate_grad_batches = 2
54
+ num_epochs_cvae = 0
55
+ num_epochs_bias = 100
56
+ gpus = [1]
57
+ seed = 0 # Setting seed to an integer value sets the seed for the pseudo-random number generators in: pytorch, numpy, python.random
58
+ num_workers = 8
59
+
60
+ # Model hyperparameters
61
+ model_type = "interaction_biased"
62
+ condition_on_ego_future = False
63
+ latent_dim = 16
64
+ hidden_dim = 128
65
+ feature_dim = 16
66
+ num_vq = 512
67
+ latent_distribution = "gaussian" # "gaussian" or "quantized"
68
+ is_mlp_residual = True
69
+ num_hidden_layers = 3
70
+ num_blocks = 3
71
+ interaction_type = "Attention" # one of "ContextGating", "Attention", "Hybrid"
72
+ ## MCG parameters
73
+ mcg_dim_expansion = 2
74
+ mcg_num_layers = 0
75
+ ## Attention parameters
76
+ num_attention_heads = 4
77
+ sequence_encoder_type = "MLP" # one of "MLP", "LSTM", "maskedLSTM"
78
+ sequence_decoder_type = "MLP" # one of "MLP", "LSTM"
79
+
80
+
81
+ # Risk Loss Hyperparameters
82
+ cost_reduce = "discounted_mean" # choose in "discounted_mean", "mean", "min", "max", "now", "final"
83
+ discount_factor = 0.95 # only used if cost_reduce == "discounted_mean", discounts the cost by this factor at each time step
84
+ min_velocity_diff = 0.1
85
+ n_mc_samples_risk = 32
86
+ n_mc_samples_biased = 16
87
+ risk_weight = 1
88
+ risk_assymetry_factor = 30
89
+ use_risk_constraint = True # For encoder_biased only
90
+ risk_constraint_update_every_n_epoch = (
91
+ 1 # For encoder_biased only, not used if use_risk_constraint == False
92
+ )
93
+ risk_constraint_weight_update_factor = (
94
+ 1.5 # For encoder_biased only, not used if use_risk_constraint == False
95
+ )
96
+ risk_constraint_weight_maximum = (
97
+ 1000 # For encoder_biased only, not used if use_risk_constraint == False
98
+ )
99
+
100
+ # List files that should be saved as log
101
+ files_to_log = [
102
+ "./risk_biased/models/biased_cvae_model.py",
103
+ "./risk_biased/models/latent_distributions.py",
104
+ ]
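As a worked check of the dimension bookkeeping above, the settings for the default split imply the following sizes (pure arithmetic from the values defined in this file):

# Dimensions implied by the settings above when data_dir == "interactive_veh_type".
num_steps = 11
map_state_dim = 2 + num_steps * 8  # 2 + 88 = 90
state_dim = 11
dynamic_state_dim = 5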
risk_biased/models/__init__.py ADDED
File without changes
risk_biased/models/biased_cvae_model.py ADDED
@@ -0,0 +1,907 @@
1
+ from warnings import warn
2
+ from typing import Callable, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from einops import rearrange, repeat
7
+
8
+ from risk_biased.models.map_encoder import MapEncoderNN
9
+ from risk_biased.models.mlp import MLP
10
+ from risk_biased.models.cvae_params import CVAEParams
11
+ from risk_biased.models.cvae_encoders import (
12
+ AbstractLatentDistribution,
13
+ CVAEEncoder,
14
+ BiasedEncoderNN,
15
+ FutureEncoderNN,
16
+ InferenceEncoderNN,
17
+ )
18
+ from risk_biased.models.cvae_decoder import (
19
+ CVAEAccelerationDecoder,
20
+ CVAEParametrizedDecoder,
21
+ DecoderNN,
22
+ )
23
+ from risk_biased.utils.cost import BaseCostTorch, get_cost
24
+ from risk_biased.utils.loss import (
25
+ reconstruction_loss,
26
+ risk_loss_function,
27
+ )
28
+ from risk_biased.models.latent_distributions import (
29
+ GaussianLatentDistribution,
30
+ QuantizedDistributionCreator,
31
+ AbstractLatentDistribution,
32
+ )
33
+ from risk_biased.utils.metrics import FDE, minFDE
34
+ from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator
35
+
36
+
37
+ class InferenceBiasedCVAE(nn.Module):
38
+ """CVAE with a biased encoder module for risk-biased trajectory forecasting.
39
+
40
+ Args:
41
+ absolute_encoder: encoder model for the absolute positions of the agents
42
+ map_encoder: encoder model for map objects
43
+ biased_encoder: biased encoder that uses past and auxiliary input,
44
+ inference_encoder: inference encoder that uses only past,
45
+ decoder: CVAE decoder model
46
+ prior_distribution: prior distribution for the latent space.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ absolute_encoder: MLP,
52
+ map_encoder: MapEncoderNN,
53
+ biased_encoder: CVAEEncoder,
54
+ inference_encoder: CVAEEncoder,
55
+ decoder: CVAEAccelerationDecoder,
56
+ prior_distribution: AbstractLatentDistribution,
57
+ ) -> None:
58
+ super().__init__()
59
+ self.biased_encoder = biased_encoder
60
+ self.inference_encoder = inference_encoder
61
+ self.decoder = decoder
62
+ self.map_encoder = map_encoder
63
+ self.absolute_encoder = absolute_encoder
64
+ self.prior_distribution = prior_distribution
65
+
66
+ def cvae_parameters(self, recurse: bool = True):
67
+ """Define an iterator over all the parameters related to the cvae."""
68
+ yield from self.absolute_encoder.parameters(recurse=recurse)
69
+ yield from self.map_encoder.parameters(recurse=recurse)
70
+ yield from self.inference_encoder.parameters(recurse=recurse)
71
+ yield from self.decoder.parameters(recurse=recurse)
72
+
73
+ def biased_parameters(self, recurse: bool = True):
74
+ """Define an iterator over only the parameters related to the biaser."""
75
+ yield from self.biased_encoder.biased_parameters(recurse=recurse)
76
+
77
+ def forward(
78
+ self,
79
+ x: torch.Tensor,
80
+ mask_x: torch.Tensor,
81
+ map: torch.Tensor,
82
+ mask_map: torch.Tensor,
83
+ offset: torch.Tensor,
84
+ *,
85
+ x_ego: Optional[torch.Tensor] = None,
86
+ y_ego: Optional[torch.Tensor] = None,
87
+ risk_level: Optional[torch.Tensor] = None,
88
+ n_samples: int = 0,
89
+ ) -> Tuple[torch.Tensor, torch.Tensor, AbstractLatentDistribution]:
90
+ """Forward function that outputs a noisy reconstruction of y and parameters of latent
91
+ posterior distribution
92
+
93
+ Args:
94
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
95
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
96
+ map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
97
+ mask_map: (batch_size, num_objects, object_sequence_length) tensor of bool mask
98
+ offset : (batch_size, num_agents, state_dim) offset position from ego. Defaults to None.
99
+ x_ego: (batch_size, 1, num_steps, state_dim) ego history
100
+ y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
101
+ risk_level (optional): (batch_size, num_agents) tensor of risk levels desired for future
102
+ trajectories. Defaults to None.
103
+ n_samples (optional): number of samples to predict, (if 0 one sample with no extra
104
+ dimension). Defaults to 0.
105
+
106
+ Returns:
107
+ noisy reconstruction y of size (batch_size, num_agents, num_steps_future, state_dim), as well as
108
+ weights of the samples and the latent distribution.
109
+ If risk_level is None, no bias is applied and only the inference encoder is used.
110
+ """
111
+
112
+ encoded_map = self.map_encoder(map, mask_map)
113
+ mask_map = mask_map.any(-1)
114
+ encoded_absolute = self.absolute_encoder(offset)
115
+
116
+ if risk_level is not None:
117
+ biased_latent_distribution = self.biased_encoder(
118
+ x,
119
+ mask_x,
120
+ encoded_absolute,
121
+ encoded_map,
122
+ mask_map,
123
+ x_ego=x_ego,
124
+ y_ego=y_ego,
125
+ offset=offset,
126
+ risk_level=risk_level,
127
+ )
128
+ inference_latent_distribution = self.inference_encoder(
129
+ x,
130
+ mask_x,
131
+ encoded_absolute,
132
+ encoded_map,
133
+ mask_map,
134
+ )
135
+ latent_distribution = inference_latent_distribution.average(
136
+ biased_latent_distribution, risk_level.unsqueeze(-1)
137
+ )
138
+ else:
139
+ latent_distribution = self.inference_encoder(
140
+ x,
141
+ mask_x,
142
+ encoded_absolute,
143
+ encoded_map,
144
+ mask_map,
145
+ )
146
+ z_sample, weights = latent_distribution.sample(n_samples=n_samples)
147
+
148
+ mask_z = mask_x.any(-1)
149
+ y_sample = self.decoder(
150
+ z_sample, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map, offset
151
+ )
152
+
153
+ return y_sample, weights, latent_distribution
154
+
155
+ def decode(
156
+ self,
157
+ z_samples: torch.Tensor,
158
+ mask_z: torch.Tensor,
159
+ x: torch.Tensor,
160
+ mask_x: torch.Tensor,
161
+ map: torch.Tensor,
162
+ mask_map: torch.Tensor,
163
+ offset: torch.Tensor,
164
+ ):
165
+ """Returns predicted y values conditioned on z_samples and the other observations.
166
+
167
+ Args:
168
+ z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of latent samples
169
+ mask_z: (batch_size, num_agents) bool mask
170
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
171
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
172
+ map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
173
+ mask_map: (batch_size, num_objects, object_sequence_length) tensor, True where map features are valid and False where they are padding
174
+ offset : (batch_size, num_agents, state_dim) offset position from ego.
175
+ """
176
+ encoded_map = self.map_encoder(map, mask_map)
177
+ mask_map = mask_map.any(-1)
178
+ encoded_absolute = self.absolute_encoder(offset)
179
+
180
+ return self.decoder(
181
+ z_samples=z_samples,
182
+ mask_z=mask_z,
183
+ x=x,
184
+ mask_x=mask_x,
185
+ encoded_absolute=encoded_absolute,
186
+ encoded_map=encoded_map,
187
+ mask_map=mask_map,
188
+ offset=offset,
189
+ )
190
+
191
+
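For orientation only, the shapes documented in the forward and decode docstrings above imply a call pattern like the sketch below; cvae is assumed to be an already-assembled InferenceBiasedCVAE (its construction from the config is omitted) and every dimension is an arbitrary illustration value, not a value taken from this repository.

import torch

batch_size, num_agents, num_steps, num_steps_future, state_dim = 4, 8, 11, 80, 5
num_objects, object_sequence_length, map_feature_dim = 16, 20, 2

x = torch.randn(batch_size, num_agents, num_steps, state_dim)             # agent history
mask_x = torch.ones(batch_size, num_agents, num_steps, dtype=torch.bool)
map_objects = torch.randn(batch_size, num_objects, object_sequence_length, map_feature_dim)
mask_map = torch.ones(batch_size, num_objects, object_sequence_length, dtype=torch.bool)
offset = torch.randn(batch_size, num_agents, state_dim)                   # offset from ego
x_ego = torch.randn(batch_size, 1, num_steps, state_dim)                  # ego history
y_ego = torch.randn(batch_size, 1, num_steps_future, state_dim)           # ego future
risk_level = torch.rand(batch_size, num_agents)

# Unbiased prediction (risk_level omitted) ...
y_pred, weights, latent_dist = cvae(x, mask_x, map_objects, mask_map, offset)
# ... and a risk-biased prediction with 16 samples per agent.
y_biased, weights, latent_dist = cvae(
    x, mask_x, map_objects, mask_map, offset,
    x_ego=x_ego, y_ego=y_ego, risk_level=risk_level, n_samples=16,
)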
192
+ class TrainingBiasedCVAE(InferenceBiasedCVAE):
193
+
194
+ """CVAE with a biased encoder module for risk-biased trajectory forecasting.
195
+ This module is a non-sampling-based version of BiasedLatentCVAE.
196
+
197
+ Args:
198
+ absolute_encoder: encoder model for the absolute positions of the agents
199
+ map_encoder: encoder model for map objects
200
+ biased_encoder: biased encoder that uses past and auxiliary input,
201
+ inference_encoder: inference encoder that uses only past,
202
+ decoder: CVAE decoder model
203
+ future_encoder: training encoder that uses past and future,
204
+ cost_function: cost function used to compute the risk objective
205
+ risk_estimator: risk estimator used to compute the risk objective
206
+ prior_distribution: prior distribution for the latent space.
207
+ training_mode (optional): set to "cvae" to train the unbiased model, set to "bias" to train
208
+ the biased encoder. Defaults to "cvae".
209
+ latent_regularization (optional): regularization term for the latent space. Defaults to 0.
210
+ risk_assymetry_factor (optional): risk asymmetry factor used to compute the risk objective while avoiding underestimation.
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ absolute_encoder: MLP,
216
+ map_encoder: MapEncoderNN,
217
+ biased_encoder: CVAEEncoder,
218
+ inference_encoder: CVAEEncoder,
219
+ decoder: CVAEAccelerationDecoder,
220
+ future_encoder: CVAEEncoder,
221
+ cost_function: BaseCostTorch,
222
+ risk_estimator: AbstractMonteCarloRiskEstimator,
223
+ prior_distribution: AbstractLatentDistribution,
224
+ training_mode: str = "cvae",
225
+ latent_regularization: float = 0.0,
226
+ risk_assymetry_factor: float = 100.0,
227
+ ) -> None:
228
+ super().__init__(
229
+ absolute_encoder,
230
+ map_encoder,
231
+ biased_encoder,
232
+ inference_encoder,
233
+ decoder,
234
+ prior_distribution,
235
+ )
236
+ self.future_encoder = future_encoder
237
+ self._cost = cost_function
238
+ self._risk = risk_estimator
239
+ self.set_training_mode(training_mode)
240
+ self.regularization_factor = latent_regularization
241
+ self.risk_assymetry_factor = risk_assymetry_factor
242
+
243
+ def cvae_parameters(self, recurse: bool = True):
244
+ yield from super().cvae_parameters(recurse)
245
+ yield from self.future_encoder.parameters(recurse)
246
+
247
+ def get_parameters(self, recurse: bool = True):
248
+ """Returns a list of two parameter iterators: cvae and encoder only."""
249
+ return [
250
+ self.cvae_parameters(recurse),
251
+ self.biased_parameters(recurse),
252
+ ]
253
+
254
+ def set_training_mode(self, training_mode: str) -> None:
255
+ """
256
+ Change the training mode (get_loss function will be different depending on the mode).
257
+
258
+ Warning: This does not freeze the decoder because the gradient must pass through it.
259
+ The decoder should be frozen at the optimizer level when changing mode.
260
+ """
261
+ assert training_mode in ["cvae", "bias"]
262
+ self.training_mode = training_mode
263
+ if training_mode == "cvae":
264
+ self.get_loss = self.get_loss_cvae
265
+ else:
266
+ self.get_loss = self.get_loss_biased
267
+
268
+ def forward_future(
269
+ self,
270
+ x: torch.Tensor,
271
+ mask_x: torch.Tensor,
272
+ map: torch.Tensor,
273
+ mask_map: torch.Tensor,
274
+ y: torch.Tensor,
275
+ mask_y: torch.Tensor,
276
+ offset: torch.Tensor,
277
+ return_inference: bool = False,
278
+ ) -> Union[
279
+ Tuple[torch.Tensor, AbstractLatentDistribution],
280
+ Tuple[torch.Tensor, AbstractLatentDistribution, AbstractLatentDistribution],
281
+ ]:
282
+ """Forward function that outputs a noisy reconstruction of y and parameters of latent
283
+ posterior distribution
284
+
285
+ Args:
286
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
287
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
288
+ map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
289
+ mask_map: (batch_size, num_objects, object_sequence_length) tensor of bool mask
290
+ y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
291
+ mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask.
292
+ offset: (batch_size, num_agents, state_dim) offset position from ego.
293
+ return_inference (optional): Set to True if the inference posterior distribution should also be returned. Defaults to False.
294
+
295
+ Returns:
296
+ noisy reconstruction y of size (batch_size, num_agents, num_steps_future, state_dim), and the
297
+ distribution of the latent posterior, as well as, optionally, the distribution of the latent inference posterior.
298
+ """
299
+
300
+ encoded_map = self.map_encoder(map, mask_map)
301
+ mask_map = mask_map.any(-1)
302
+ encoded_absolute = self.absolute_encoder(offset)
303
+
304
+ latent_distribution = self.future_encoder(
305
+ x,
306
+ mask_x,
307
+ y=y,
308
+ mask_y=mask_y,
309
+ encoded_absolute=encoded_absolute,
310
+ encoded_map=encoded_map,
311
+ mask_map=mask_map,
312
+ )
313
+ z_sample, weights = latent_distribution.sample()
314
+ mask_z = mask_x.any(-1)
315
+
316
+ y_sample = self.decoder(
317
+ z_sample,
318
+ mask_z,
319
+ x,
320
+ mask_x,
321
+ encoded_absolute,
322
+ encoded_map,
323
+ mask_map,
324
+ offset,
325
+ )
326
+
327
+ if return_inference:
328
+ inference_distribution = self.inference_encoder(
329
+ x,
330
+ mask_x,
331
+ encoded_absolute,
332
+ encoded_map,
333
+ mask_map,
334
+ )
335
+
336
+ return (
337
+ y_sample,
338
+ latent_distribution,
339
+ inference_distribution,
340
+ )
341
+ else:
342
+ return y_sample, latent_distribution
343
+
344
+ def get_loss_cvae(
345
+ self,
346
+ x: torch.Tensor,
347
+ mask_x: torch.Tensor,
348
+ map: torch.Tensor,
349
+ mask_map: torch.Tensor,
350
+ y: torch.Tensor,
351
+ *,
352
+ mask_y: torch.Tensor,
353
+ mask_loss: torch.Tensor,
354
+ offset: torch.Tensor,
355
+ unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
356
+ kl_weight: float,
357
+ kl_threshold: float,
358
+ **kwargs,
359
+ ) -> Tuple[torch.Tensor, dict]:
360
+ """Compute and return risk-biased CVAE loss averaged over batch and sequence time steps,
361
+ along with desired loss-related metrics for logging
362
+
363
+ Args:
364
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
365
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
366
+ map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
367
+ mask_map: (batch_size, num_objects, object_sequence_length) tensor, True where map features are valid and False where they are padding
368
+ y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
369
+ mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask.
370
+ mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool mask set to True where the loss
371
+ should be computed and to False where it shouldn't
372
+ offset : (batch_size, num_agents, state_dim) offset position from ego.
373
+ unnormalizer: function that takes in a trajectory and an offset and that outputs the
374
+ unnormalized trajectory
375
+ kl_weight: weight to apply to the KL loss (normal value is 1.0, larger values can be
376
+ used for disentanglement)
377
+ kl_threshold: minimum float value threshold applied to the KL loss
378
+
379
+ Returns:
380
+ torch.Tensor: (1,) loss tensor
381
+ dict: dict that contains loss-related metrics to be logged
382
+ """
383
+ log_dict = dict()
384
+
385
+ if not mask_loss.any():
386
+ warn("A batch is dropped because the whole loss is masked.")
387
+ return torch.zeros(1, requires_grad=True), {}
388
+
389
+ mask_z = mask_x.any(-1)
390
+ # sum_mask_z = mask_z.float().sum().clamp_min(1)
391
+
392
+ (y_sample, latent_distribution, inference_distribution) = self.forward_future(
393
+ x,
394
+ mask_x,
395
+ map,
396
+ mask_map,
397
+ y,
398
+ mask_y,
399
+ offset,
400
+ return_inference=True,
401
+ )
402
+
403
+ # sum_mask_z *= latent_distribution.mu.shape[-1]
404
+
405
+ # log_dict["latent/abs_mean"] = (
406
+ # (latent_distribution.mu.abs() * mask_z.unsqueeze(-1).float()).sum() / sum_mask_z
407
+ # ).item()
408
+ # log_dict["latent/std"] = (
409
+ # (latent_distribution.logvar.exp() * mask_z.unsqueeze(-1).float()).sum() / sum_mask_z
410
+ # ).item()
411
+ log_dict["fde/encoded"] = FDE(
412
+ unnormalizer(y_sample, offset), unnormalizer(y, offset), mask_loss
413
+ ).item()
414
+ rec_loss = reconstruction_loss(y_sample, y, mask_loss)
415
+
416
+ kl_loss = latent_distribution.kl_loss(
417
+ inference_distribution,
418
+ kl_threshold,
419
+ mask_z,
420
+ )
421
+
422
+ # self.prior_distribution.to(latent_distribution.mu.device)
423
+
424
+ kl_loss_prior = latent_distribution.kl_loss(
425
+ self.prior_distribution,
426
+ kl_threshold,
427
+ mask_z,
428
+ )
429
+
430
+ sampling_loss = latent_distribution.sampling_loss()
431
+
432
+ log_dict["loss/rec"] = rec_loss.item()
433
+ log_dict["loss/kl"] = kl_loss.item()
434
+ log_dict["loss/kl_prior"] = kl_loss_prior.item()
435
+ log_dict["loss/sampling"] = sampling_loss.item()
436
+ log_dict.update(latent_distribution.log_dict("future"))
437
+ log_dict.update(inference_distribution.log_dict("inference"))
438
+
439
+ loss = (
440
+ rec_loss
441
+ + kl_weight * kl_loss
442
+ + self.regularization_factor * kl_loss_prior
443
+ + sampling_loss
444
+ )
445
+
446
+ log_dict["loss/total"] = loss.item()
447
+
448
+ return loss, log_dict
449
+
450
+ def get_loss_biased(
451
+ self,
452
+ x: torch.Tensor,
453
+ mask_x: torch.Tensor,
454
+ map: torch.Tensor,
455
+ mask_map: torch.Tensor,
456
+ y: torch.Tensor,
457
+ *,
458
+ mask_loss: torch.Tensor,
459
+ offset: torch.Tensor,
460
+ unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
461
+ risk_level: torch.Tensor,
462
+ x_ego: torch.Tensor,
463
+ y_ego: torch.Tensor,
464
+ kl_weight: float,
465
+ kl_threshold: float,
466
+ risk_weight: float,
467
+ n_samples_risk: int,
468
+ n_samples_biased: int,
469
+ dt: float,
470
+ **kwargs,
471
+ ) -> Tuple[torch.Tensor, dict]:
472
+ """Compute and return risk-biased CVAE loss averaged over batch and sequence time steps,
473
+ along with desired loss-related metrics for logging
474
+
475
+ Args:
476
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
477
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
478
+ map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
479
+ mask_map: (batch_size, num_objects, object_sequence_length) tensor, True where map features are valid and False where they are padding
480
+ y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
481
+ mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool mask set to True where the loss
482
+ should be computed and to False where it shouldn't
483
+ offset : (batch_size, num_agents, state_dim) offset position from ego.
484
+ unnormalizer: function that takes in a trajectory and an offset and that outputs the
485
+ unnormalized trajectory
486
+ risk_level: (batch_size, num_agents) tensor of risk levels desired for future trajectories
487
+ x_ego: (batch_size, 1, num_steps, state_dim) tensor of ego history
488
+ y_ego: (batch_size, 1, num_steps_future, state_dim) tensor of ego future trajectory
489
+ kl_weight: weight to apply to the KL loss (normal value is 1.0, larger values can be
490
+ used for disentanglement)
491
+ kl_threshold: minimum float value threshold applied to the KL loss
492
+ risk_weight: weight to apply to the risk loss (beta parameter in our document)
493
+ n_samples_risk: number of samples to use for Monte-Carlo estimation of the risk using the unbiased distribution
494
+ n_samples_biased: number of samples to use for Monte-Carlo estimation of the risk using the biased distribution
495
+ dt: time step in trajectories
496
+
497
+ Returns:
498
+ torch.Tensor: (1,) loss tensor
499
+ dict: dict that contains loss-related metrics to be logged
500
+ """
501
+ log_dict = dict()
502
+
503
+ if not mask_loss.any():
504
+ warn("A batch is dropped because the whole loss is masked.")
505
+ return torch.zeros(1, requires_grad=True), {}
506
+
507
+ mask_z = mask_x.any(-1)
508
+
509
+ # Computing unbiased samples
510
+ n_samples_risk = max(1, n_samples_risk)
511
+ n_samples_biased = max(1, n_samples_biased)
512
+ cost = []
513
+ weights = []
514
+ pack_size = min(n_samples_risk, n_samples_biased)
515
+ with torch.no_grad():
516
+ encoded_map = self.map_encoder(map, mask_map)
517
+ mask_map = mask_map.any(-1)
518
+ encoded_absolute = self.absolute_encoder(offset)
519
+
520
+ inference_distribution = self.inference_encoder(
521
+ x,
522
+ mask_x,
523
+ encoded_absolute,
524
+ encoded_map,
525
+ mask_map,
526
+ )
527
+ for _ in range(n_samples_risk // pack_size):
528
+ z_samples, w = inference_distribution.sample(
529
+ n_samples=pack_size,
530
+ )
531
+
532
+ y_samples = self.decoder(
533
+ z_samples=z_samples,
534
+ mask_z=mask_z,
535
+ x=x,
536
+ mask_x=mask_x,
537
+ encoded_absolute=encoded_absolute,
538
+ encoded_map=encoded_map,
539
+ mask_map=mask_map,
540
+ offset=offset,
541
+ )
542
+
543
+ mask_loss_samples = repeat(mask_loss, "b a t -> b a s t", s=pack_size)
544
+ # Computing unbiased cost
545
+ cost.append(
546
+ get_cost(
547
+ self._cost,
548
+ x,
549
+ y_samples,
550
+ offset,
551
+ x_ego,
552
+ y_ego,
553
+ dt,
554
+ unnormalizer,
555
+ mask_loss_samples,
556
+ )
557
+ )
558
+ weights.append(w)
559
+
560
+ cost = torch.cat(cost, 2)
561
+ weights = torch.cat(weights, 2)
562
+ risk_cost = self._risk(risk_level, cost, weights)
563
+
564
+ log_dict["fde/prior"] = FDE(
565
+ unnormalizer(y_samples, offset),
566
+ unnormalizer(y, offset).unsqueeze(-3),
567
+ mask_loss_samples,
568
+ ).item()
569
+
570
+ mask_cost_samples = repeat(mask_z, "b a -> b a s", s=n_samples_risk)
571
+ mean_cost = (cost * mask_cost_samples.float() * weights).sum(2) / (
572
+ (mask_cost_samples.float() * weights).sum(2).clamp_min(1)
573
+ )
574
+ log_dict["cost/mean"] = (
575
+ (mean_cost * mask_loss.any(-1).float()).sum()
576
+ / (mask_loss.any(-1).float().sum())
577
+ ).item()
578
+
579
+ # Computing biased latent parameters
580
+ biased_distribution = self.biased_encoder(
581
+ x,
582
+ mask_x,
583
+ encoded_absolute.detach(),
584
+ encoded_map.detach(),
585
+ mask_map,
586
+ risk_level=risk_level,
587
+ x_ego=x_ego,
588
+ y_ego=y_ego,
589
+ offset=offset,
590
+ )
591
+ biased_distribution = inference_distribution.average(
592
+ biased_distribution, risk_level.unsqueeze(-1)
593
+ )
594
+
595
+ # sum_mask_z = mask_z.float().sum().clamp_min(1)* biased_distribution.mu.shape[-1]
596
+ # log_dict["latent/abs_mean_biased"] = (
597
+ # (biased_distribution.mu.abs() * mask_z.unsqueeze(-1).float()).sum() / sum_mask_z
598
+ # ).item()
599
+ # log_dict["latent/var_biased"] = (
600
+ # (biased_distribution.logvar.exp() * mask_z.unsqueeze(-1).float()).sum() / sum_mask_z
601
+ # ).item()
602
+
603
+ # Computing biased samples
604
+ z_biased_samples, weights = biased_distribution.sample(
605
+ n_samples=n_samples_biased
606
+ )
607
+ mask_z_samples = repeat(mask_z, "b a -> b a s ()", s=n_samples_biased)
608
+ log_dict["latent/abs_samples_biased"] = (
609
+ (z_biased_samples.abs() * mask_z_samples.float()).sum()
610
+ / (mask_z_samples.float().sum())
611
+ ).item()
612
+
613
+ y_biased_samples = self.decoder(
614
+ z_samples=z_biased_samples,
615
+ mask_z=mask_z,
616
+ x=x,
617
+ mask_x=mask_x,
618
+ encoded_absolute=encoded_absolute,
619
+ encoded_map=encoded_map,
620
+ mask_map=mask_map,
621
+ offset=offset,
622
+ )
623
+
624
+ log_dict["fde/prior_biased"] = FDE(
625
+ unnormalizer(y_biased_samples, offset),
626
+ unnormalizer(y, offset).unsqueeze(2),
627
+ mask_loss=mask_loss_samples,
628
+ ).item()
629
+
630
+ # Computing biased cost
631
+ biased_cost = get_cost(
632
+ self._cost,
633
+ x,
634
+ y_biased_samples,
635
+ offset,
636
+ x_ego,
637
+ y_ego,
638
+ dt,
639
+ unnormalizer,
640
+ mask_loss_samples,
641
+ )
642
+ mask_cost_samples = mask_z_samples.squeeze(-1)
643
+ mean_biased_cost = (biased_cost * mask_cost_samples.float() * weights).sum(
644
+ 2
645
+ ) / ((mask_cost_samples.float() * weights).sum(2).clamp_min(1))
646
+ log_dict["cost/mean_biased"] = (
647
+ (mean_biased_cost * mask_loss.any(-1).float()).sum()
648
+ / (mask_loss.any(-1).float().sum())
649
+ ).item()
650
+
651
+ log_dict["cost/risk"] = (
652
+ (risk_cost * mask_loss.any(-1).float()).sum()
653
+ / (mask_loss.any(-1).float().sum())
654
+ ).item()
655
+
656
+ # Computing loss between risk and biased cost
657
+ risk_loss = risk_loss_function(
658
+ mean_biased_cost,
659
+ risk_cost.detach(),
660
+ mask_loss.any(-1),
661
+ self.risk_assymetry_factor,
662
+ )
663
+ log_dict["loss/risk"] = risk_loss.item()
664
+
665
+ # Computing KL loss between prior and biased latent
666
+ kl_loss = inference_distribution.kl_loss(
667
+ biased_distribution,
668
+ kl_threshold,
669
+ mask_z=mask_z,
670
+ )
671
+ log_dict["loss/kl"] = kl_loss.item()
672
+
673
+ loss = risk_weight * risk_loss + kl_weight * kl_loss
674
+ log_dict["loss/total"] = loss.item()
675
+
676
+ log_dict["loss/risk_weight"] = risk_weight
677
+ log_dict.update(inference_distribution.log_dict("inference"))
678
+ log_dict.update(biased_distribution.log_dict("biased"))
679
+
680
+ return loss, log_dict
681
+
682
+ def get_prediction_accuracy(
683
+ self,
684
+ x: torch.Tensor,
685
+ mask_x: torch.Tensor,
686
+ map: torch.Tensor,
687
+ mask_map: torch.Tensor,
688
+ y: torch.Tensor,
689
+ mask_loss: torch.Tensor,
690
+ x_ego: torch.Tensor,
691
+ y_ego: torch.Tensor,
692
+ offset: torch.Tensor,
693
+ unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
694
+ risk_level: torch.Tensor,
695
+ num_samples_min_fde: int = 0,
696
+ ) -> dict:
697
+ """
698
+ A function that calls the predict method and returns a dict that contains prediction
699
+ metrics, which measure accuracy with respect to ground-truth future trajectory y
700
+ Args:
701
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
702
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
703
+ map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
704
+ mask_map: (batch_size, num_objects, object_sequence_length) tensor of bool mask, True where map features are valid and False where they are padding
705
+ y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
706
+ mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool mask set to True where the loss
707
+ should be computed and to False where it shouldn't
708
+ x_ego: (batch_size, 1, num_steps, state_dim) tensor of ego history
709
+ y_ego: (batch_size, 1, num_steps_future, state_dim) tensor of ego future trajectory
710
+ offset: (batch_size, num_agents, state_dim) offset position from ego
711
+
712
+ unnormalizer: function that takes in a trajectory and an offset and that outputs the
713
+ unnormalized trajectory
714
+ risk_level: (batch_size, num_agents) tensor of risk levels desired for future trajectories
715
+ num_samples_min_fde: number of samples to use when computing the minimum final displacement error
716
+ Returns:
717
+ dict: dict that contains prediction-related metrics to be logged
718
+ """
719
+ log_dict = dict()
720
+ with torch.no_grad():
721
+ batch_size = x.shape[0]
722
+ beg = 0
723
+ y_predict = []
724
+
725
+ # Limit the batch size so the num_samples_min_fde value does not impact the memory usage
726
+ for i in range(batch_size // num_samples_min_fde + 1):
727
+ sub_batch_size = num_samples_min_fde
728
+ end = beg + sub_batch_size
729
+
730
+ y_predict.append(
731
+ unnormalizer(
732
+ self.forward(
733
+ x=x[beg:end],
734
+ mask_x=mask_x[beg:end],
735
+ map=map[beg:end],
736
+ mask_map=mask_map[beg:end],
737
+ offset=offset[beg:end],
738
+ x_ego=x_ego[beg:end],
739
+ y_ego=y_ego[beg:end],
740
+ risk_level=None,
741
+ n_samples=num_samples_min_fde,
742
+ )[0],
743
+ offset[beg:end],
744
+ )
745
+ )
746
+ beg = end
747
+ if beg >= batch_size:
748
+ break
749
+
750
+ # Limit the batch size so the num_samples_min_fde value does not impact the memory usage
751
+ if risk_level is not None:
752
+ y_predict_biased = []
753
+ beg = 0
754
+ for i in range(batch_size // num_samples_min_fde + 1):
755
+ sub_batch_size = num_samples_min_fde
756
+ end = beg + sub_batch_size
757
+ y_predict_biased.append(
758
+ unnormalizer(
759
+ self.forward(
760
+ x=x[beg:end],
761
+ mask_x=mask_x[beg:end],
762
+ map=map[beg:end],
763
+ mask_map=mask_map[beg:end],
764
+ offset=offset[beg:end],
765
+ x_ego=x_ego[beg:end],
766
+ y_ego=y_ego[beg:end],
767
+ risk_level=risk_level[beg:end],
768
+ n_samples=num_samples_min_fde,
769
+ )[0],
770
+ offset[beg:end],
771
+ )
772
+ )
773
+ beg = end
774
+ if beg >= batch_size:
775
+ break
776
+ y_predict_biased = torch.cat(y_predict_biased, 0)
777
+ if num_samples_min_fde > 0:
778
+ repeated_mask_loss = repeat(
779
+ mask_loss, "b a t -> b a samples t", samples=num_samples_min_fde
780
+ )
781
+ log_dict["fde/prior_biased"] = FDE(
782
+ y_predict_biased, y.unsqueeze(-3), mask_loss=repeated_mask_loss
783
+ ).item()
784
+ log_dict["minfde/prior_biased"] = minFDE(
785
+ y_predict_biased, y.unsqueeze(-3), mask_loss=repeated_mask_loss
786
+ ).item()
787
+ else:
788
+ log_dict["fde/prior_biased"] = FDE(
789
+ y_predict_biased, y, mask_loss=mask_loss
790
+ ).item()
791
+
792
+ y_predict = torch.cat(y_predict, 0)
793
+ y_unnormalized = unnormalizer(y, offset)
794
+ if num_samples_min_fde > 0:
795
+ repeated_mask_loss = repeat(
796
+ mask_loss, "b a t -> b a samples t", samples=num_samples_min_fde
797
+ )
798
+ log_dict["fde/prior"] = FDE(
799
+ y_predict, y_unnormalized.unsqueeze(-3), mask_loss=repeated_mask_loss
800
+ ).item()
801
+ log_dict["minfde/prior"] = minFDE(
802
+ y_predict, y_unnormalized.unsqueeze(-3), mask_loss=repeated_mask_loss
803
+ ).item()
804
+ else:
805
+ log_dict["fde/prior"] = FDE(
806
+ y_predict, y_unnormalized, mask_loss=mask_loss
807
+ ).item()
808
+ return log_dict
809
+
810
+
811
+ def cvae_factory(
812
+ params: CVAEParams,
813
+ cost_function: BaseCostTorch,
814
+ risk_estimator: AbstractMonteCarloRiskEstimator,
815
+ training_mode: str = "cvae",
816
+ ):
817
+ """Biased CVAE with a biased MLP encoder and an MLP decoder
818
+ Args:
819
+ params: dataclass defining the necessary parameters
820
+ cost_function: cost function used to compute the risk objective
821
+ risk_estimator: risk estimator used to compute the risk objective
822
+ training_mode: "inference", "cvae" or "bias" set what is the training mode
823
+ latent_distribution: "gaussian" or "quantized" set the latent distribution
824
+ """
825
+
826
+ absolute_encoder_nn = MLP(
827
+ params.dynamic_state_dim,
828
+ params.hidden_dim,
829
+ params.hidden_dim,
830
+ params.num_hidden_layers,
831
+ params.is_mlp_residual,
832
+ )
833
+
834
+ map_encoder_nn = MapEncoderNN(params)
835
+
836
+ if params.latent_distribution == "gaussian":
837
+ latent_distribution_creator = GaussianLatentDistribution
838
+ prior_distribution = GaussianLatentDistribution(
839
+ torch.zeros(1, 1, 2 * params.latent_dim)
840
+ )
841
+ future_encoder_latent_dim = 2 * params.latent_dim
842
+ inference_encoder_latent_dim = 2 * params.latent_dim
843
+ biased_encoder_latent_dim = 2 * params.latent_dim
844
+ elif params.latent_distribution == "quantized":
845
+ latent_distribution_creator = QuantizedDistributionCreator(
846
+ params.latent_dim, params.num_vq
847
+ )
848
+ prior_distribution = latent_distribution_creator(
849
+ torch.zeros(1, 1, params.num_vq)
850
+ )
851
+ future_encoder_latent_dim = params.latent_dim
852
+ inference_encoder_latent_dim = params.num_vq
853
+ biased_encoder_latent_dim = params.num_vq
854
+
855
+ biased_encoder_nn = BiasedEncoderNN(
856
+ params,
857
+ biased_encoder_latent_dim,
858
+ num_steps=params.num_steps,
859
+ )
860
+ biased_encoder = CVAEEncoder(
861
+ biased_encoder_nn, latent_distribution_creator=latent_distribution_creator
862
+ )
863
+
864
+ future_encoder_nn = FutureEncoderNN(
865
+ params, future_encoder_latent_dim, params.num_steps + params.num_steps_future
866
+ )
867
+ future_encoder = CVAEEncoder(
868
+ future_encoder_nn, latent_distribution_creator=latent_distribution_creator
869
+ )
870
+
871
+ inference_encoder_nn = InferenceEncoderNN(
872
+ params, inference_encoder_latent_dim, params.num_steps
873
+ )
874
+ inference_encoder = CVAEEncoder(
875
+ inference_encoder_nn, latent_distribution_creator=latent_distribution_creator
876
+ )
877
+
878
+ decoder_nn = DecoderNN(params)
879
+ decoder = CVAEAccelerationDecoder(decoder_nn)
880
+ # decoder = CVAEParametrizedDecoder(decoder_nn)
881
+
882
+ if training_mode == "inference":
883
+ cvae = InferenceBiasedCVAE(
884
+ absolute_encoder_nn,
885
+ map_encoder_nn,
886
+ biased_encoder,
887
+ inference_encoder,
888
+ decoder,
889
+ prior_distribution=prior_distribution,
890
+ )
891
+ cvae.eval()
892
+ return cvae
893
+ else:
894
+ return TrainingBiasedCVAE(
895
+ absolute_encoder_nn,
896
+ map_encoder_nn,
897
+ biased_encoder,
898
+ inference_encoder,
899
+ decoder,
900
+ future_encoder=future_encoder,
901
+ cost_function=cost_function,
902
+ risk_estimator=risk_estimator,
903
+ training_mode=training_mode,
904
+ latent_regularization=params.latent_regularization,
905
+ risk_assymetry_factor=params.risk_assymetry_factor,
906
+ prior_distribution=prior_distribution,
907
+ )
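The masked, importance-weighted averaging over the sample dimension used above (for `mean_cost` and `mean_biased_cost` in `get_loss_biased`) can be illustrated in isolation. Below is a minimal sketch with toy shapes and values; nothing in it is taken from the repository's configs.

```python
import torch

# Toy shapes: batch=2, agents=3, samples=4 (illustrative only).
cost = torch.rand(2, 3, 4)              # per-sample costs, as returned by get_cost
weights = torch.ones(2, 3, 4) / 4       # sample weights from the latent distribution
mask = torch.tensor([[True, True, False],
                     [True, False, False]])       # (batch, agents) agent validity
mask_samples = mask.unsqueeze(-1).expand_as(cost)

# Weighted mean over the sample dimension, guarding against fully masked agents,
# mirroring (cost * mask * weights).sum(2) / ((mask * weights).sum(2).clamp_min(1))
mean_cost = (cost * mask_samples.float() * weights).sum(2) / (
    (mask_samples.float() * weights).sum(2).clamp_min(1)
)
print(mean_cost.shape)  # torch.Size([2, 3]); fully masked agents get a zero mean cost
```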
risk_biased/models/context_gating.py ADDED
@@ -0,0 +1,53 @@
1
+ import torch
2
+ from torch import nn
3
+ from risk_biased.models.mlp import MLP
4
+
5
+
6
+ def pool(x, dim):
7
+ x, _ = x.max(dim)
8
+ return x
9
+
10
+
11
+ class ContextGating(nn.Module):
12
+ """Inspired by Multi-Path++ https://arxiv.org/pdf/2111.14973v3.pdf (but not the same)
13
+
14
+ Args:
15
+ d_model: input dimension of the model
16
+ d: hidden dimension of the model
17
+ num_layers: number of layers of the MLP blocks
18
+ is_mlp_residual: whether to use residual connections in the MLP blocks
19
+ """
20
+
21
+ def __init__(self, d_model, d, num_layers, is_mlp_residual):
22
+ super().__init__()
23
+
24
+ self.w_s = MLP(d_model, d, int((d_model + d) / 2), num_layers, is_mlp_residual)
25
+ self.w_c_cross = MLP(
26
+ d_model, d, int((d_model + d) / 2), num_layers, is_mlp_residual
27
+ )
28
+ self.w_c_global = MLP(d, d, d, num_layers, is_mlp_residual)
29
+
30
+ self.output_layer = nn.Linear(d, d_model)
31
+
32
+ def forward(self, s, c_cross, c_global):
33
+ """context gating forward function
34
+
35
+ Args:
36
+
37
+ s: (batch, agents, features) tensor of agent encoded states
38
+ c_cross: (batch, objects, features) tensor of objects encoded states
39
+ c_global: (batch, d) tensor of global context
40
+
41
+ Returns:
42
+
43
+ s: (batch, agents, features) updated tensor of agent encoded states
44
+ c_global: updated tensor of global context
45
+
46
+ """
47
+ s = self.w_s(s)
48
+ c_cross = self.w_c_cross(c_cross)
49
+ c_global = pool(c_cross, -2) * self.w_c_global(c_global)
50
+ # b: batch, a: agents, k: features
51
+ s = torch.einsum("bak,bk->bak", [s, c_global])
52
+ s = self.output_layer(s)
53
+ return s, c_global
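A small usage sketch of the `ContextGating` block defined above, assuming the `risk_biased` package is importable; the dimensions below are illustrative and not taken from the config files.

```python
import torch
from risk_biased.models.context_gating import ContextGating

# Illustrative sizes: 2 scenes, 3 agents, 5 context objects, d_model=16, d=8
gating = ContextGating(d_model=16, d=8, num_layers=2, is_mlp_residual=True)
s = torch.randn(2, 3, 16)        # encoded agent states
c_cross = torch.randn(2, 5, 16)  # encoded object states
c_global = torch.randn(2, 8)     # global context vector

# Objects are max-pooled into the global context, which then gates each agent feature
s_out, c_global_out = gating(s, c_cross, c_global)
print(s_out.shape, c_global_out.shape)  # torch.Size([2, 3, 16]) torch.Size([2, 8])
```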
risk_biased/models/cvae_decoder.py ADDED
@@ -0,0 +1,388 @@
1
+ from einops import rearrange, repeat
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from risk_biased.models.cvae_params import CVAEParams
6
+ from risk_biased.models.nn_blocks import (
7
+ MCG,
8
+ MAB,
9
+ MHB,
10
+ SequenceDecoderLSTM,
11
+ SequenceDecoderMLP,
12
+ SequenceEncoderLSTM,
13
+ SequenceEncoderMLP,
14
+ SequenceEncoderMaskedLSTM,
15
+ )
16
+
17
+
18
+ class DecoderNN(nn.Module):
19
+ """Decoder neural network that decodes input tensors into a single output tensor.
20
+ It contains an interaction layer that (re-)compute the interactions between the agents in the scene.
21
+ This implies that a given latent sample for one agent will be affecting the predictions of the othe agents too.
22
+
23
+ Args:
24
+ params: dataclass defining the necessary parameters
25
+
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ params: CVAEParams,
31
+ ) -> None:
32
+ super().__init__()
33
+ self.dt = params.dt
34
+ self.state_dim = params.state_dim
35
+ self.dynamic_state_dim = params.dynamic_state_dim
36
+ self.hidden_dim = params.hidden_dim
37
+ self.num_steps_future = params.num_steps_future
38
+ self.latent_dim = params.latent_dim
39
+
40
+ if params.sequence_encoder_type == "MLP":
41
+ self._agent_encoder_past = SequenceEncoderMLP(
42
+ params.state_dim,
43
+ params.hidden_dim,
44
+ params.num_hidden_layers,
45
+ params.num_steps,
46
+ params.is_mlp_residual,
47
+ )
48
+ elif params.sequence_encoder_type == "LSTM":
49
+ self._agent_encoder_past = SequenceEncoderLSTM(
50
+ params.state_dim, params.hidden_dim
51
+ )
52
+ elif params.sequence_encoder_type == "maskedLSTM":
53
+ self._agent_encoder_past = SequenceEncoderMaskedLSTM(
54
+ params.state_dim, params.hidden_dim
55
+ )
56
+ else:
57
+ raise RuntimeError(
58
+ f"Got sequence encoder type {params.sequence_decoder_type} but only knows one of: 'MLP', 'LSTM', 'maskedLSTM' "
59
+ )
60
+
61
+ self._combine_z_past = nn.Linear(
62
+ params.hidden_dim + params.latent_dim, params.hidden_dim
63
+ )
64
+
65
+ if params.interaction_type == "Attention" or params.interaction_type == "MAB":
66
+ self._interaction = MAB(
67
+ params.hidden_dim, params.num_attention_heads, params.num_blocks
68
+ )
69
+ elif (
70
+ params.interaction_type == "ContextGating"
71
+ or params.interaction_type == "MCG"
72
+ ):
73
+ self._interaction = MCG(
74
+ params.hidden_dim,
75
+ params.mcg_dim_expansion,
76
+ params.mcg_num_layers,
77
+ params.num_blocks,
78
+ params.is_mlp_residual,
79
+ )
80
+ elif params.interaction_type == "Hybrid" or params.interaction_type == "MHB":
81
+ self._interaction = MHB(
82
+ params.hidden_dim,
83
+ params.num_attention_heads,
84
+ params.mcg_dim_expansion,
85
+ params.mcg_num_layers,
86
+ params.num_blocks,
87
+ params.is_mlp_residual,
88
+ )
89
+ else:
90
+ self._interaction = lambda x, *args, **kwargs: x
91
+
92
+ if params.sequence_decoder_type == "MLP":
93
+ self._decoder = SequenceDecoderMLP(
94
+ params.hidden_dim,
95
+ params.num_hidden_layers,
96
+ params.num_steps_future,
97
+ params.is_mlp_residual,
98
+ )
99
+ elif params.sequence_decoder_type == "LSTM":
100
+ self._decoder = SequenceDecoderLSTM(params.hidden_dim)
101
+ elif params.sequence_decoder_type == "maskedLSTM":
102
+ self._decoder = SequenceDecoderLSTM(params.hidden_dim)
103
+ else:
104
+ raise RuntimeError(
105
+ f"Got sequence decoder type {params.sequence_decoder_type} but only knows one of: 'MLP', 'LSTM', 'maskedLSTM' "
106
+ )
107
+
108
+ def forward(
109
+ self,
110
+ z_samples: torch.Tensor,
111
+ mask_z: torch.Tensor,
112
+ x: torch.Tensor,
113
+ mask_x: torch.Tensor,
114
+ encoded_absolute: torch.Tensor,
115
+ encoded_map: torch.Tensor,
116
+ mask_map: torch.Tensor,
117
+ ) -> torch.Tensor:
118
+ """Forward function that decodes input tensors into an output tensor of size
119
+ (batch_size, num_agents, (n_samples), num_steps_future, state_dim)
120
+
121
+ Args:
122
+ z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of history
123
+ mask_z: (batch_size, num_agents) tensor of bool mask
124
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
125
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
126
+ encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
127
+ encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
128
+ mask_map: (batch_size, num_objects) tensor of bool mask
129
+
130
+ Returns:
131
+ (batch_size, num_agents, (n_samples), num_steps_future, state_dim) output tensor
132
+ """
133
+
134
+ encoded_x = self._agent_encoder_past(x, mask_x)
135
+ squeeze_output_sample_dim = False
136
+ if z_samples.ndim == 3:
137
+ batch_size, num_agents, latent_dim = z_samples.shape
138
+ num_samples = 1
139
+ z_samples = rearrange(z_samples, "b a l -> b a () l")
140
+ squeeze_output_sample_dim = True
141
+ else:
142
+ batch_size, num_agents, num_samples, latent_dim = z_samples.shape
143
+ mask_z = repeat(mask_z, "b a -> (b s) a", s=num_samples)
144
+ mask_map = repeat(mask_map, "b o -> (b s) o", s=num_samples)
145
+ encoded_x = repeat(encoded_x, "b a l -> (b s) a l", s=num_samples)
146
+ encoded_absolute = repeat(
147
+ encoded_absolute, "b a l -> (b s) a l", s=num_samples
148
+ )
149
+ encoded_map = repeat(encoded_map, "b o l -> (b s) o l", s=num_samples)
150
+
151
+ z_samples = rearrange(z_samples, "b a s l -> (b s) a l")
152
+
153
+ h = self._combine_z_past(torch.cat([z_samples, encoded_x], dim=-1))
154
+
155
+ h = self._interaction(h, mask_z, encoded_absolute, encoded_map, mask_map)
156
+
157
+ h = self._decoder(h, self.num_steps_future)
158
+
159
+ if not squeeze_output_sample_dim:
160
+ h = rearrange(h, "(b s) a t l -> b a s t l", b=batch_size, s=num_samples)
161
+
162
+ return h
163
+
164
+
165
+ class CVAEAccelerationDecoder(nn.Module):
166
+ """Decoder architecture for conditional variational autoencoder
167
+
168
+ Args:
169
+ model: decoder neural network that transforms input tensors to an output sequence
170
+ """
171
+
172
+ def __init__(
173
+ self,
174
+ model: nn.Module,
175
+ ) -> None:
176
+ super().__init__()
177
+ self._model = model
178
+ self._output_layer = nn.Linear(model.hidden_dim, 2)
179
+
180
+ def forward(
181
+ self,
182
+ z_samples: torch.Tensor,
183
+ mask_z: torch.Tensor,
184
+ x: torch.Tensor,
185
+ mask_x: torch.Tensor,
186
+ encoded_absolute: torch.Tensor,
187
+ encoded_map: torch.Tensor,
188
+ mask_map: torch.Tensor,
189
+ offset: torch.Tensor,
190
+ ) -> torch.Tensor:
191
+ """Forward function that decodes input tensors into an output tensor of size
192
+ (batch_size, num_agents, (n_samples), num_steps_future, state_dim=5)
193
+ It first predicts accelerations that are doubly integrated to produce the output
194
+ state sequence with positions angles and velocities (x, y, theta, vx, vy) or (x, y, vx, vy) or (x, y)
195
+
196
+ Args:
197
+ z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of history
198
+ mask_z: (batch_size, num_agents) tensor of bool mask
199
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
200
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
201
+ encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
202
+ encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
203
+ mask_map: (batch_size, num_objects) tensor of bool mask
204
+
205
+ Returns:
206
+ (batch_size, num_agents, (n_samples), num_steps_future, state_dim) output tensor. Sample dimension
207
+ does not exist if z_samples is a 2D tensor.
208
+ """
209
+
210
+ h = self._model(
211
+ z_samples, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map
212
+ )
213
+ h = self._output_layer(h)
214
+
215
+ dt = self._model.dt
216
+ initial_position = x[..., -1:, :2].clone()
217
+ # If shape is 5 it should be (x, y, angle, vx, vy)
218
+ if offset.shape[-1] == 5:
219
+ initial_velocity = offset[..., 3:5].clone().unsqueeze(-2)
220
+ # else if shape is 4 it should be (x, y, vx, vy)
221
+ elif offset.shape[-1] == 4:
222
+ initial_velocity = offset[..., 2:4].clone().unsqueeze(-2)
223
+ elif x.shape[-1] == 5:
224
+ initial_velocity = x[..., -1:, 3:5].clone()
225
+ elif x.shape[-1] == 4:
226
+ initial_velocity = x[..., -1:, 2:4].clone()
227
+ else:
228
+ initial_velocity = (x[..., -1:, :] - x[..., -2:-1, :]) / dt
229
+
230
+ output = torch.zeros(
231
+ (*h.shape[:-1], self._model.dynamic_state_dim), device=h.device
232
+ )
233
+ # There might be a sample dimension in the output tensor, then adapt the shape of initial position and velocity
234
+ if output.ndim == 5:
235
+ initial_position = initial_position.unsqueeze(-3)
236
+ initial_velocity = initial_velocity.unsqueeze(-3)
237
+
238
+ if self._model.dynamic_state_dim == 5:
239
+ output[..., 3:5] = h.cumsum(-2) * dt
240
+ output[..., :2] = (output[..., 3:5].clone() + initial_velocity).cumsum(
241
+ -2
242
+ ) * dt + initial_position
243
+ output[..., 2] = torch.atan2(output[..., 4].clone(), output[..., 3].clone())
244
+ elif self._model.dynamic_state_dim == 4:
245
+ output[..., 2:4] = h.cumsum(-2) * dt
246
+ output[..., :2] = (output[..., 2:4].clone() + initial_velocity).cumsum(
247
+ -2
248
+ ) * dt + initial_position
249
+ else:
250
+ velocity = h.cumsum(-2) * dt
251
+ output = (velocity.clone() + initial_velocity).cumsum(
252
+ -2
253
+ ) * dt + initial_position
254
+ return output
255
+
256
+
257
+ class CVAEParametrizedDecoder(nn.Module):
258
+ """Decoder architecture for conditional variational autoencoder
259
+
260
+ Args:
261
+ model: decoder neural network that transforms input tensors to an output sequence
262
+ """
263
+
264
+ def __init__(
265
+ self,
266
+ model: nn.Module,
267
+ ) -> None:
268
+ super().__init__()
269
+ self._model = model
270
+ self._order = 3
271
+ self._output_layer = nn.Linear(
272
+ model.hidden_dim * model.num_steps_future,
273
+ 2 * self._order + model.num_steps_future,
274
+ )
275
+
276
+ def polynomial(self, x: torch.Tensor, params: torch.Tensor):
277
+ """Polynomial function that takes a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future) and
278
+ a parameter tensor of shape (batch_size, num_agents, (n_samples), self._order*2) and returns a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future)
279
+ """
280
+ h = x.clone()
281
+ squeeze = False
282
+ if h.ndim == 3:
283
+ h = h.unsqueeze(2)
284
+ params = params.unsqueeze(2)
285
+ squeeze = True
286
+ h = repeat(
287
+ h,
288
+ "batch agents samples sequence -> batch agents samples sequence two order",
289
+ order=self._order,
290
+ two=2,
291
+ ).cumprod(-1)
292
+ h = h * params.view(*params.shape[:-1], 1, 2, self._order)
293
+ h = h.sum(-1)
294
+ if squeeze:
295
+ h = h.squeeze(2)
296
+ return h
297
+
298
+ def dpolynomial(self, x: torch.Tensor, params: torch.Tensor):
299
+ """Derivative of the polynomial function that takes a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future) and
300
+ a parameter tensor of shape (batch_size, num_agents, (n_samples), self._order*2) and returns a tensor of shape (batch_size, num_agents, (n_samples), num_steps_future)
301
+ """
302
+ h = x.clone()
303
+ squeeze = False
304
+ if h.ndim == 3:
305
+ h = h.unsqueeze(2)
306
+ params = params.unsqueeze(2)
307
+ squeeze = True
308
+ h = repeat(
309
+ h,
310
+ "batch agents samples sequence -> batch agents samples sequence two order",
311
+ order=self._order - 1,
312
+ two=2,
313
+ )
314
+ h = torch.cat((torch.ones_like(h[..., :1]), h.cumprod(-1)), -1)
315
+ h = h * params.view(*params.shape[:-1], 1, 2, self._order)
316
+ h = h * torch.arange(self._order).view(*([1] * params.ndim), -1).to(x.device)
317
+ h = h.sum(-1)
318
+ if squeeze:
319
+ h = h.squeeze(2)
320
+ return h
321
+
322
+ def forward(
323
+ self,
324
+ z_samples: torch.Tensor,
325
+ mask_z: torch.Tensor,
326
+ x: torch.Tensor,
327
+ mask_x: torch.Tensor,
328
+ encoded_absolute: torch.Tensor,
329
+ encoded_map: torch.Tensor,
330
+ mask_map: torch.Tensor,
331
+ offset: torch.Tensor,
332
+ ) -> torch.Tensor:
333
+ """Forward function that decodes input tensors into an output tensor of size
334
+ (batch_size, num_agents, (n_samples), num_steps_future, state_dim=5)
335
+ It first predicts accelerations that are doubly integrated to produce the output
336
+ state sequence with positions angles and velocities (x, y, theta, vx, vy) or (x, y, vx, vy) or (x, y)
337
+
338
+ Args:
339
+ z_samples: (batch_size, num_agents, (n_samples), latent_dim) tensor of history
340
+ mask_z: (batch_size, num_agents) tensor of bool mask
341
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history for all agents
342
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
343
+ encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
344
+ encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
345
+ mask_map: (batch_size, num_objects) tensor of bool mask
346
+
347
+ Returns:
348
+ (batch_size, num_agents, (n_samples), num_steps_future, state_dim) output tensor. Sample dimension
349
+ does not exist if z_samples is a 2D tensor.
350
+ """
351
+
352
+ squeeze_output_sample_dim = z_samples.ndim == 3
353
+ batch_size = z_samples.shape[0]
354
+
355
+ h = self._model(
356
+ z_samples, mask_z, x, mask_x, encoded_absolute, encoded_map, mask_map
357
+ )
358
+ if squeeze_output_sample_dim:
359
+ h = rearrange(
360
+ h, "batch agents sequence features -> batch agents (sequence features)"
361
+ )
362
+ else:
363
+ h = rearrange(
364
+ h,
365
+ "(batch samples) agents sequence features -> batch agents samples (sequence features)",
366
+ batch=batch_size,
367
+ )
368
+ h = self._output_layer(h)
369
+
370
+ output = torch.zeros(
371
+ (
372
+ *h.shape[:-1],
373
+ self._model.num_steps_future,
374
+ self._model.dynamic_state_dim,
375
+ ),
376
+ device=h.device,
377
+ )
378
+ params = h[..., : 2 * self._order]
379
+ dldt = torch.relu(h[..., 2 * self._order :])
380
+ distance = dldt.cumsum(-2)
381
+ output[..., :2] = self.polynomial(distance, params)
382
+ if self._model.dynamic_state_dim == 5:
383
+ output[..., 3:5] = dldt * self.dpolynomial(distance, params)
384
+ output[..., 2] = torch.atan2(output[..., 4].clone(), output[..., 3].clone())
385
+ elif self._model.dynamic_state_dim == 4:
386
+ output[..., 2:4] = dldt * self.dpolynomial(distance, params)
387
+
388
+ return output
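The `CVAEAccelerationDecoder` above turns predicted accelerations into positions by integrating twice with the time step `dt`. Below is a standalone sketch of that integration for the position-only case; shapes, `dt`, and initial conditions are illustrative assumptions.

```python
import torch

dt = 0.1                                     # illustrative time step
acc = torch.randn(2, 3, 10, 2)               # (batch, agents, steps_future, 2) accelerations
initial_position = torch.zeros(2, 3, 1, 2)   # last observed position
initial_velocity = torch.ones(2, 3, 1, 2)    # last observed velocity

# Mirrors the final branch of CVAEAccelerationDecoder.forward:
velocity = acc.cumsum(-2) * dt                                      # integrate acceleration
position = (velocity + initial_velocity).cumsum(-2) * dt + initial_position
print(position.shape)  # torch.Size([2, 3, 10, 2])
```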
risk_biased/models/cvae_encoders.py ADDED
@@ -0,0 +1,376 @@
1
+ from typing import Optional
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from risk_biased.models.cvae_params import CVAEParams
8
+ from risk_biased.models.nn_blocks import (
9
+ MCG,
10
+ MAB,
11
+ MHB,
12
+ SequenceEncoderLSTM,
13
+ SequenceEncoderMLP,
14
+ SequenceEncoderMaskedLSTM,
15
+ )
16
+ from risk_biased.models.latent_distributions import AbstractLatentDistribution
17
+
18
+
19
+ class BaseEncoderNN(nn.Module):
20
+ """Base encoder neural network that defines the common functionality of encoders.
21
+ It should not be used directly but rather extended to define specific encoders.
22
+
23
+ Args:
24
+ params: dataclass defining the necessary parameters
25
+ num_steps: length of the input sequence
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ params: CVAEParams,
31
+ latent_dim: int,
32
+ num_steps: int,
33
+ ) -> None:
34
+ super().__init__()
35
+ self.is_mlp_residual = params.is_mlp_residual
36
+ self.num_hidden_layers = params.num_hidden_layers
37
+ self.num_steps = params.num_steps
38
+ self.num_steps_future = params.num_steps_future
39
+ self.sequence_encoder_type = params.sequence_encoder_type
40
+ self.state_dim = params.state_dim
41
+ self.latent_dim = latent_dim
42
+ self.hidden_dim = params.hidden_dim
43
+
44
+ if params.sequence_encoder_type == "MLP":
45
+ self._agent_encoder = SequenceEncoderMLP(
46
+ params.state_dim,
47
+ params.hidden_dim,
48
+ params.num_hidden_layers,
49
+ num_steps,
50
+ params.is_mlp_residual,
51
+ )
52
+ elif params.sequence_encoder_type == "LSTM":
53
+ self._agent_encoder = SequenceEncoderLSTM(
54
+ params.state_dim, params.hidden_dim
55
+ )
56
+ elif params.sequence_encoder_type == "maskedLSTM":
57
+ self._agent_encoder = SequenceEncoderMaskedLSTM(
58
+ params.state_dim, params.hidden_dim
59
+ )
60
+
61
+ if params.interaction_type == "Attention" or params.interaction_type == "MAB":
62
+ self._interaction = MAB(
63
+ params.hidden_dim, params.num_attention_heads, params.num_blocks
64
+ )
65
+ elif (
66
+ params.interaction_type == "ContextGating"
67
+ or params.interaction_type == "MCG"
68
+ ):
69
+ self._interaction = MCG(
70
+ params.hidden_dim,
71
+ params.mcg_dim_expansion,
72
+ params.mcg_num_layers,
73
+ params.num_blocks,
74
+ params.is_mlp_residual,
75
+ )
76
+ elif params.interaction_type == "Hybrid" or params.interaction_type == "MHB":
77
+ self._interaction = MHB(
78
+ params.hidden_dim,
79
+ params.num_attention_heads,
80
+ params.mcg_dim_expansion,
81
+ params.mcg_num_layers,
82
+ params.num_blocks,
83
+ params.is_mlp_residual,
84
+ )
85
+ else:
86
+ self._interaction = lambda x, *args, **kwargs: x
87
+ self._output_layer = nn.Linear(params.hidden_dim, self.latent_dim)
88
+
89
+ def encode_agents(self, x: torch.Tensor, mask_x: torch.Tensor, *args, **kwargs):
90
+ raise NotImplementedError
91
+
92
+ def forward(
93
+ self,
94
+ x: torch.Tensor,
95
+ mask_x: torch.Tensor,
96
+ encoded_absolute: torch.Tensor,
97
+ encoded_map: torch.Tensor,
98
+ mask_map: torch.Tensor,
99
+ y: Optional[torch.Tensor] = None,
100
+ mask_y: Optional[torch.Tensor] = None,
101
+ x_ego: Optional[torch.Tensor] = None,
102
+ y_ego: Optional[torch.Tensor] = None,
103
+ offset: Optional[torch.Tensor] = None,
104
+ risk_level: Optional[torch.Tensor] = None,
105
+ ) -> torch.Tensor:
106
+ """Forward function that encodes input tensors into an output tensor of dimension
107
+ latent_dim.
108
+
109
+ Args:
110
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
111
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
112
+ encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
113
+ encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
114
+ mask_map: (batch_size, num_objects) tensor of bool mask
115
+ y (optional): (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
116
+ mask_y (optional): (batch_size, num_agents, num_steps_future) tensor of bool mask. Defaults to None.
117
+ x_ego: (batch_size, 1, num_steps, state_dim) ego history
118
+ y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
119
+ offset (optional): (batch_size, num_agents, state_dim) offset position from ego.
120
+ risk_level (optional): (batch_size, num_agents) tensor of risk levels desired for future
121
+ trajectories. Defaults to None.
122
+
123
+ Returns:
124
+ (batch_size, num_agents, latent_dim) output tensor
125
+ """
126
+ h_agents = self.encode_agents(
127
+ x=x,
128
+ mask_x=mask_x,
129
+ y=y,
130
+ mask_y=mask_y,
131
+ x_ego=x_ego,
132
+ y_ego=y_ego,
133
+ offset=offset,
134
+ risk_level=risk_level,
135
+ )
136
+ mask_agent = mask_x.any(-1)
137
+ h_agents = self._interaction(
138
+ h_agents, mask_agent, encoded_absolute, encoded_map, mask_map
139
+ )
140
+
141
+ return self._output_layer(h_agents)
142
+
143
+
144
+ class BiasedEncoderNN(BaseEncoderNN):
145
+ """Biased encoder neural network that encodes past info and auxiliary input
146
+ into a biased distribution over the latent space.
147
+
148
+ Args:
149
+ params: dataclass defining the necessary parameters
150
+ num_steps: length of the input sequence
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ params: CVAEParams,
156
+ latent_dim: int,
157
+ num_steps: int,
158
+ ) -> None:
159
+ super().__init__(params, latent_dim, num_steps)
160
+ self.condition_on_ego_future = params.condition_on_ego_future
161
+ if params.sequence_encoder_type == "MLP":
162
+ self._ego_encoder = SequenceEncoderMLP(
163
+ params.state_dim,
164
+ params.hidden_dim,
165
+ params.num_hidden_layers,
166
+ params.num_steps
167
+ + params.num_steps_future * self.condition_on_ego_future,
168
+ params.is_mlp_residual,
169
+ )
170
+ elif params.sequence_encoder_type == "LSTM":
171
+ self._ego_encoder = SequenceEncoderLSTM(params.state_dim, params.hidden_dim)
172
+ elif params.sequence_encoder_type == "maskedLSTM":
173
+ self._ego_encoder = SequenceEncoderMaskedLSTM(
174
+ params.state_dim, params.hidden_dim
175
+ )
176
+
177
+ self._auxiliary_encode = nn.Linear(
178
+ params.hidden_dim + 1 + params.hidden_dim, params.hidden_dim
179
+ )
180
+
181
+ def biased_parameters(self, recurse: bool = True):
182
+ """Get the parameters to be optimized when training to bias."""
183
+ yield from self.parameters(recurse)
184
+
185
+ def encode_agents(
186
+ self,
187
+ x: torch.Tensor,
188
+ mask_x: torch.Tensor,
189
+ *,
190
+ x_ego: torch.Tensor,
191
+ y_ego: torch.Tensor,
192
+ offset: torch.Tensor,
193
+ risk_level: torch.Tensor,
194
+ **kwargs,
195
+ ):
196
+ """Encode agent input and auxiliary input into a feature vector.
197
+
198
+ Args:
199
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
200
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
201
+ x_ego: (batch_size, 1, num_steps, state_dim) ego history
202
+ y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
203
+ offset: (batch_size, num_agents, state_dim) offset position from ego.
204
+ risk_level: (batch_size, num_agents) tensor of risk levels desired for future
205
+ trajectories. Defaults to None.
206
+ Returns:
207
+ (batch_size, latent_dim) output tensor
208
+ """
209
+
210
+ if self.condition_on_ego_future:
211
+ ego_tensor = torch.cat([x_ego, y_ego], dim=-2)
212
+ else:
213
+ ego_tensor = x_ego
214
+
215
+ risk_feature = ((risk_level - 0.5) * 10).exp().unsqueeze(-1)
216
+ mask_ego = torch.ones(
217
+ ego_tensor.shape[0],
218
+ offset.shape[1],
219
+ ego_tensor.shape[2],
220
+ device=ego_tensor.device,
221
+ )
222
+ batch_size, n_agents, dynamic_state_dim = offset.shape
223
+ state_dim = ego_tensor.shape[-1]
224
+ extended_offset = torch.cat(
225
+ (
226
+ offset,
227
+ torch.zeros(
228
+ batch_size,
229
+ n_agents,
230
+ state_dim - dynamic_state_dim,
231
+ device=offset.device,
232
+ ),
233
+ ),
234
+ dim=-1,
235
+ ).unsqueeze(-2)
236
+ if extended_offset.shape[1] > 1:
237
+ ego_encoded = self._ego_encoder(
238
+ ego_tensor + extended_offset[:, :1] - extended_offset, mask_ego
239
+ )
240
+ else:
241
+ ego_encoded = self._ego_encoder(ego_tensor - extended_offset, mask_ego)
242
+ auxiliary_input = torch.cat((risk_feature, ego_encoded), -1)
243
+
244
+ h_agents = self._agent_encoder(x, mask_x)
245
+ h_agents = torch.cat([h_agents, auxiliary_input], dim=-1)
246
+ h_agents = self._auxiliary_encode(h_agents)
247
+
248
+ return h_agents
249
+
250
+
251
+ class InferenceEncoderNN(BaseEncoderNN):
252
+ """Inference encoder neural network that encodes past info into the
253
+ inference distribution over the latent space.
254
+
255
+ Args:
256
+ params: dataclass defining the necessary parameters
257
+ num_steps: length of the input sequence
258
+ """
259
+
260
+ def biaser_parameters(self, recurse: bool = True):
261
+ yield from []
262
+
263
+ def encode_agents(self, x: torch.Tensor, mask_x: torch.Tensor, *args, **kwargs):
264
+ h_agents = self._agent_encoder(x, mask_x)
265
+ return h_agents
266
+
267
+
268
+ class FutureEncoderNN(BaseEncoderNN):
269
+ """Future encoder neural network that encodes past and future info into the
270
+ future-conditioned distribution over the latent space.
271
+ The future is not available at test time, this is only used for training.
272
+
273
+ Args:
274
+ params: dataclass defining the necessary parameters
275
+ num_steps: length of the input sequence
276
+
277
+ """
278
+
279
+ def biaser_parameters(self, recurse: bool = True):
280
+ """The future encoder is not optimized when training to bias."""
281
+ yield from []
282
+
283
+ def encode_agents(
284
+ self,
285
+ x: torch.Tensor,
286
+ mask_x: torch.Tensor,
287
+ *,
288
+ y: torch.Tensor,
289
+ mask_y: torch.Tensor,
290
+ **kwargs,
291
+ ):
292
+ """Encode agent input and future input into a feature vector.
293
+ Args:
294
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of trajectory history
295
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
296
+ y: (batch_size, num_agents, num_steps_future, state_dim) future trajectory
297
+ mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask
298
+ """
299
+ mask_traj = torch.cat([mask_x, mask_y], dim=-1)
300
+ h_agents = self._agent_encoder(torch.cat([x, y], dim=-2), mask_traj)
301
+ return h_agents
302
+
303
+
304
+ class CVAEEncoder(nn.Module):
305
+ """Encoder architecture for conditional variational autoencoder
306
+
307
+ Args:
308
+ model: encoder neural network that transforms input tensors to an unsplit latent output
309
+ latent_distribution_creator: Class that creates a latent distribution class for the latent space.
310
+ """
311
+
312
+ def __init__(
313
+ self,
314
+ model: BaseEncoderNN,
315
+ latent_distribution_creator,
316
+ ) -> None:
317
+ super().__init__()
318
+ self._model = model
319
+ self.latent_dim = model.latent_dim
320
+ self._latent_distribution_creator = latent_distribution_creator
321
+
322
+ def biased_parameters(self, recurse: bool = True):
323
+ yield from self._model.biased_parameters(recurse)
324
+
325
+ def forward(
326
+ self,
327
+ x: torch.Tensor,
328
+ mask_x: torch.Tensor,
329
+ encoded_absolute: torch.Tensor,
330
+ encoded_map: torch.Tensor,
331
+ mask_map: torch.Tensor,
332
+ y: Optional[torch.Tensor] = None,
333
+ mask_y: Optional[torch.Tensor] = None,
334
+ x_ego: Optional[torch.Tensor] = None,
335
+ y_ego: Optional[torch.Tensor] = None,
336
+ offset: Optional[torch.Tensor] = None,
337
+ risk_level: Optional[torch.Tensor] = None,
338
+ ) -> AbstractLatentDistribution:
339
+ """Forward function that encodes input tensors into an output tensor of dimension
340
+ latent_dim.
341
+
342
+ Args:
343
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
344
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
345
+ encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
346
+ encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
347
+ mask_map: (batch_size, num_objects) tensor of bool mask
348
+ y (optional): (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
349
+ mask_y (optional): (batch_size, num_agents, num_steps_future) tensor of bool mask. Defaults to None.
350
+ x_ego (optional): (batch_size, 1, num_steps, state_dim) ego history
351
+ y_ego (optional): (batch_size, 1, num_steps_future, state_dim) ego future
352
+ offset (optional): (batch_size, num_agents, state_dim) offset position from ego.
353
+ risk_level (optional): (batch_size, num_agents) tensor of risk levels desired for future
354
+ trajectories. Defaults to None.
355
+
356
+ Returns:
357
+ Latent distribution representing the posterior over the latent variables given the input observations.
358
+ """
359
+
360
+ latent_output = self._model(
361
+ x=x,
362
+ mask_x=mask_x,
363
+ encoded_absolute=encoded_absolute,
364
+ encoded_map=encoded_map,
365
+ mask_map=mask_map,
366
+ y=y,
367
+ mask_y=mask_y,
368
+ x_ego=x_ego,
369
+ y_ego=y_ego,
370
+ offset=offset,
371
+ risk_level=risk_level,
372
+ )
373
+
374
+ latent_distribution = self._latent_distribution_creator(latent_output)
375
+
376
+ return latent_distribution
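In `BiasedEncoderNN.encode_agents` above, the requested risk level is turned into an auxiliary feature with `((risk_level - 0.5) * 10).exp()` before being concatenated with the encoded ego trajectory. A tiny sketch of what that transform does (the input values are illustrative):

```python
import torch

risk_level = torch.tensor([0.0, 0.5, 1.0])       # desired risk levels in [0, 1]
risk_feature = ((risk_level - 0.5) * 10).exp()   # exp(-5), exp(0), exp(5)
print(risk_feature)  # tensor([6.7379e-03, 1.0000e+00, 1.4841e+02])
```

The feature spans several orders of magnitude between the lowest and highest risk levels, which makes the conditioning signal easy for the encoder to pick up.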
risk_biased/models/cvae_params.py ADDED
@@ -0,0 +1,78 @@
1
+ from dataclasses import dataclass
2
+
3
+ from mmcv import Config
4
+
5
+
6
+ @dataclass
7
+ class CVAEParams:
8
+ """
9
+ state_dim: Dimension of the state at each time step.
10
+ map_state_dim: Dimension of the map point features at each position.
11
+ num_steps: Number of time steps in the past trajectory input.
12
+ num_steps_future: Number of time steps in the future trajectory output.
13
+ latent_dim: Dimension of the latent space
14
+ hidden_dim: Dimension of the hidden layers
15
+ num_hidden_layers: Number of layers for each model, (encoder, decoder)
16
+ is_mlp_residual: Set to True to add a linear transformation of the input to the output of the MLP
17
+ interaction_type: Whether to use MCG, MAB, or MHB to handle interactions
18
+ num_attention_heads: Number of attention heads to use in MHA blocks
19
+ mcg_dim_expansion: Dimension expansion factor for the MCG global interaction space
20
+ mcg_num_layers: Number of layers for the MLP MCG blocks
21
+ num_blocks: Number of interaction blocks to use
22
+ sequence_encoder_type: Type of sequence encoder maskedLSTM, LSTM, or MLP
23
+ sequence_decoder_type: Type of sequence decoder maskedLSTM, LSTM, or MLP
24
+ condition_on_ego_future: Whether to condition the biasing on the ego future or only on the ego past
25
+ latent_regularization: Weight of the latent regularization loss
26
+ """
27
+
28
+ dt: float
29
+ state_dim: int
30
+ dynamic_state_dim: int
31
+ map_state_dim: int
32
+ max_size_lane: int
33
+ num_steps: int
34
+ num_steps_future: int
35
+ latent_dim: int
36
+ hidden_dim: int
37
+ num_hidden_layers: int
38
+ is_mlp_residual: bool
39
+ interaction_type: int
40
+ num_attention_heads: int
41
+ mcg_dim_expansion: int
42
+ mcg_num_layers: int
43
+ num_blocks: int
44
+ sequence_encoder_type: str
45
+ sequence_decoder_type: str
46
+ condition_on_ego_future: bool
47
+ latent_regularization: float
48
+ risk_assymetry_factor: float
49
+ num_vq: int
50
+ latent_distribution: str
51
+
52
+ @staticmethod
53
+ def from_config(cfg: Config):
54
+ return CVAEParams(
55
+ dt=cfg.dt,
56
+ state_dim=cfg.state_dim,
57
+ dynamic_state_dim=cfg.dynamic_state_dim,
58
+ map_state_dim=cfg.map_state_dim,
59
+ max_size_lane=cfg.max_size_lane,
60
+ num_steps=cfg.num_steps,
61
+ num_steps_future=cfg.num_steps_future,
62
+ latent_dim=cfg.latent_dim,
63
+ hidden_dim=cfg.hidden_dim,
64
+ num_hidden_layers=cfg.num_hidden_layers,
65
+ is_mlp_residual=cfg.is_mlp_residual,
66
+ interaction_type=cfg.interaction_type,
67
+ mcg_dim_expansion=cfg.mcg_dim_expansion,
68
+ mcg_num_layers=cfg.mcg_num_layers,
69
+ num_blocks=cfg.num_blocks,
70
+ num_attention_heads=cfg.num_attention_heads,
71
+ sequence_encoder_type=cfg.sequence_encoder_type,
72
+ sequence_decoder_type=cfg.sequence_decoder_type,
73
+ condition_on_ego_future=cfg.condition_on_ego_future,
74
+ latent_regularization=cfg.latent_regularization,
75
+ risk_assymetry_factor=cfg.risk_assymetry_factor,
76
+ num_vq=cfg.num_vq,
77
+ latent_distribution=cfg.latent_distribution,
78
+ )
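A hedged sketch of building `CVAEParams` through `CVAEParams.from_config`, assuming the `risk_biased` package and `mmcv` are importable; every value below is a placeholder chosen for illustration, not a setting from the repository's config files (see `risk_biased/config` for those).

```python
from mmcv import Config
from risk_biased.models.cvae_params import CVAEParams

# Placeholder values, for illustration only.
cfg = Config(
    dict(
        dt=0.1, state_dim=5, dynamic_state_dim=5, map_state_dim=2, max_size_lane=16,
        num_steps=10, num_steps_future=20, latent_dim=16, hidden_dim=64,
        num_hidden_layers=2, is_mlp_residual=True, interaction_type="MAB",
        num_attention_heads=4, mcg_dim_expansion=2, mcg_num_layers=2, num_blocks=2,
        sequence_encoder_type="LSTM", sequence_decoder_type="LSTM",
        condition_on_ego_future=True, latent_regularization=0.1,
        risk_assymetry_factor=10.0, num_vq=64, latent_distribution="gaussian",
    )
)
params = CVAEParams.from_config(cfg)
```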
risk_biased/models/latent_distributions.py ADDED
@@ -0,0 +1,468 @@
1
+ from typing import Optional, Callable, Tuple
2
+ import warnings
3
+
4
+ from abc import ABC, abstractmethod
5
+ from einops import rearrange, repeat
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ def relaxed_one_hot_categorical_without_replacement(temperature, logits, num_samples=1):
11
+ # See paper Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement (https://arxiv.org/pdf/1903.06059.pdf)
12
+ # for explanation of the trick
13
+ scores = (
14
+ (torch.distributions.Gumbel(logits, 1).rsample() / temperature)
15
+ .softmax(-1)
16
+ .clamp_min(1e-10)
17
+ )
18
+ top_scores, top_indices = torch.topk(
19
+ scores,
20
+ num_samples,
21
+ dim=-1,
22
+ )
23
+ return scores, top_indices
24
+
25
+
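The Gumbel-Top-k trick referenced above can be checked in isolation: adding Gumbel noise to the logits and keeping the top-k indices draws k distinct categories, with higher-logit categories more likely to appear. A minimal sketch with illustrative numbers:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0, -1.0])                    # unnormalized log-probabilities
gumbel_scores = torch.distributions.Gumbel(logits, 1.0).rsample()
_, top_indices = torch.topk(gumbel_scores, k=2, dim=-1)
print(top_indices)  # two distinct category indices, sampled without replacement
```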
26
+ class AbstractLatentDistribution(nn.Module, ABC):
27
+ """Base class for latent distribution"""
28
+
29
+ @abstractmethod
30
+ def sample(
31
+ self, num_samples: int, *args, **kwargs
32
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
33
+ """Sample from the latent distribution."""
34
+
35
+ @abstractmethod
36
+ def kl_loss(
37
+ self,
38
+ other: "GaussianLatentDistribution",
39
+ threshold: float = 0,
40
+ mask_z: Optional[torch.Tensor] = None,
41
+ ) -> torch.Tensor:
42
+ """Compute the KL divergence between two latent distributions."""
43
+
44
+ @abstractmethod
45
+ def sampling_loss(self) -> torch.Tensor:
46
+ """Loss of the latent distribution."""
47
+
48
+ @abstractmethod
49
+ def average(
50
+ self, other: "AbstractLatentDistribution", weight_other: torch.Tensor
51
+ ) -> "AbstractLatentDistribution":
52
+ """Average of the latent distribution."""
53
+
54
+ @abstractmethod
55
+ def log_dict(self, type: str) -> dict:
56
+ """Log the latent distribution values."""
57
+
58
+
59
+ class GaussianLatentDistribution(AbstractLatentDistribution):
60
+ """Gaussian latent distribution"""
61
+
62
+ def __init__(self, latent_representation: torch.Tensor):
63
+ super().__init__()
64
+ mu, logvar = torch.chunk(latent_representation, 2, dim=-1)
65
+ self.register_buffer("mu", mu, False)
66
+ self.register_buffer("logvar", logvar, False)
67
+
68
+ def sample(
69
+ self, n_samples: int = 0, *args, **kwargs
70
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
71
+ """Sample from Gaussian with a reparametrization trick
72
+
73
+ Args:
74
+ n_samples (optional): number of samples to make, (if 0 one sample with no extra
75
+ dimension). Defaults to 0.
76
+ Returns:
77
+ Random Gaussian sample of size (some_shape, (n_samples), latent_dim)
78
+ """
79
+
80
+ std = (self.logvar / 2).exp()
81
+ if n_samples <= 0:
82
+ eps = torch.randn_like(std)
83
+ latent_samples = self.mu + eps * std
84
+ weights = torch.ones_like(latent_samples[..., 0])
85
+ else:
86
+ eps = torch.randn(
87
+ [*std.shape[:-1], n_samples, self.mu.shape[-1]], device=std.device
88
+ )
89
+ # Reshape
90
+ latent_samples = self.mu.unsqueeze(-2) + eps * std.unsqueeze(-2)
91
+ weights = torch.ones_like(latent_samples[..., 0]) / n_samples
92
+ return latent_samples, weights
93
+
94
+ def kl_loss(
95
+ self,
96
+ other: "GaussianLatentDistribution",
97
+ threshold: float = 0,
98
+ mask_z: Optional[torch.Tensor] = None,
99
+ ) -> torch.Tensor:
100
+ """Compute the KL divergence between two latent distributions."""
101
+ assert type(other) == GaussianLatentDistribution
102
+ kl_loss = (
103
+ (other.logvar
104
+ - self.logvar
105
+ + ((self.mu - other.mu).square() + self.logvar.exp()) / other.logvar.exp()
106
+ - 1)*0.5
107
+ ).clamp_min(threshold)
108
+ if mask_z is None:
109
+ return kl_loss.mean()
110
+ else:
111
+ assert mask_z.any()
112
+ return torch.sum(kl_loss.mean(-1) * mask_z) / torch.sum(mask_z)
113
+
114
+ def sampling_loss(self) -> torch.Tensor:
115
+ return torch.zeros(1, device=self.mu.device)
116
+
117
+ def average(
118
+ self, other: "GaussianLatentDistribution", weight_other: torch.Tensor
119
+ ) -> "GaussianLatentDistribution":
120
+ assert type(other) == GaussianLatentDistribution
121
+ assert other.mu.shape == self.mu.shape
122
+ average_log_var = (
123
+ self.logvar.exp() * (1 - weight_other) + other.logvar.exp() * weight_other
124
+ ).log()
125
+ return GaussianLatentDistribution(
126
+ torch.cat(
127
+ (
128
+ self.mu * (1 - weight_other) + other.mu * weight_other,
129
+ average_log_var,
130
+ ),
131
+ dim=-1,
132
+ )
133
+ )
134
+
135
+ def log_dict(self, type: str) -> dict:
136
+ return {
137
+ f"latent/{type}/abs_mean": self.mu.abs().mean(),
138
+ f"latent/{type}/std": (self.logvar * 0.5).exp().mean(),
139
+ }
140
+
141
+
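The KL term in `GaussianLatentDistribution.kl_loss` above is the closed-form divergence between diagonal Gaussians, KL(N(mu1, s1^2) || N(mu2, s2^2)) = 0.5 * (log(s2^2 / s1^2) + (s1^2 + (mu1 - mu2)^2) / s2^2 - 1). A quick check of the per-dimension expression against `torch.distributions` (the numbers are illustrative):

```python
import torch
from torch.distributions import Normal, kl_divergence

mu1, logvar1 = torch.tensor(0.3), torch.tensor(-0.5)
mu2, logvar2 = torch.tensor(-0.1), torch.tensor(0.2)

# Same per-dimension expression as in GaussianLatentDistribution.kl_loss
kl_manual = 0.5 * (
    logvar2 - logvar1 + ((mu1 - mu2).square() + logvar1.exp()) / logvar2.exp() - 1
)
kl_reference = kl_divergence(
    Normal(mu1, (logvar1 / 2).exp()), Normal(mu2, (logvar2 / 2).exp())
)
print(torch.allclose(kl_manual, kl_reference))  # True
```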
142
+ class QuantizedLatentDistribution(AbstractLatentDistribution):
143
+ """Quantized latent distribution.
144
+ It is defined with a codebook of quantized latents and a continuous latent.
145
+ The distribution is based on distances of the continuous latent to the codebook.
146
+ Sampling only quantizes the continuous latent.
147
+
148
+ Args:
149
+ continuous_latent : Continuous latent representation of shape (some_shape, latent_dim)
150
+ codebook : Codebook of shape (num_embeddings, latent_dim)
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ continuous_latent: torch.Tensor,
156
+ codebook: torch.Tensor,
157
+ flush_weights: Callable[[], None],
158
+ get_weights: Callable[[], torch.Tensor],
159
+ index_add_one_weights: Callable[[torch.Tensor], None],
160
+ ):
161
+ super().__init__()
162
+ self.register_buffer("continuous_latent", continuous_latent, False)
163
+ self.register_buffer("codebook", codebook, False)
164
+ self.flush_weights = flush_weights
165
+ self.get_weights = get_weights
166
+ self.index_add_one_weights = index_add_one_weights
167
+ self.quantization_loss = None
168
+ self.accuracy = None
169
+
170
+ def sample(
171
+ self, n_samples: int = 0, *args, **kwargs
172
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
173
+ """Quantize the continuous latent from the latent dictionary.
174
+
175
+ Args:
176
+ latent: (batch_size, num_agents, latent_dim) Continuous latent input
177
+
178
+ Returns:
179
+ quantized_latent, quantization_loss
180
+ """
181
+ assert n_samples == 0, "Only one sample is supported for quantized latent"
182
+
183
+ distances_to_quantized = (
184
+ (
185
+ self.codebook.view(1, 1, *self.codebook.shape)
186
+ - self.continuous_latent.unsqueeze(-2)
187
+ )
188
+ .square()
189
+ .sum(-1)
190
+ )
191
+ batch_size, num_agents, num_vq = distances_to_quantized.shape
192
+
193
+ self.soft_one_hot = (
194
+ (-100 * distances_to_quantized)
195
+ .softmax(dim=-1)
196
+ .view(batch_size, num_agents, num_vq)
197
+ )
198
+ # quantized, args_selected = self.sample(soft_one_hot)
199
+ _, args_selected = torch.min(distances_to_quantized, dim=-1)
200
+ quantized = self.codebook[args_selected, :]
201
+ args_selected = args_selected.view(-1)
202
+
203
+ # Update weights
204
+ self.index_add_one_weights(args_selected)
205
+
206
+ distances_to_quantized = distances_to_quantized.view(
207
+ batch_size * num_agents, num_vq
208
+ )
209
+
210
+ # Resample useless latent vectors
211
+ random_latents = self.continuous_latent.view(
212
+ batch_size * num_agents, self.codebook.shape[-1]
213
+ )[torch.randint(batch_size * num_agents, (num_vq,))]
214
+ codebook_weights = self.get_weights()
215
+ total_samples = codebook_weights.sum()
216
+ # TODO: The value 100 is arbitrary, should it be a parameter?
217
+ # The uselessness of a codebook vector is defined by the number of times it has been sampled
218
+ # if it has been sampled less than 1% of the time, it is pushed towards a random continuous latent sample
219
+ # this prevents the codebook from being dominated by a few vectors
220
+ self.uselessness = (
221
+ (
222
+ torch.where(
223
+ (codebook_weights < total_samples / (100 * num_vq)).unsqueeze(-1),
224
+ random_latents.detach() - self.codebook,
225
+ torch.zeros_like(self.codebook),
226
+ ).abs()
227
+ + 1
228
+ )
229
+ .log()
230
+ .sum(-1)
231
+ .mean()
232
+ )
233
+ # TODO: The value 1e6 is arbitrary, should it be a parameter?
234
+ if total_samples > 1e6 * num_vq:
235
+ # Flush the codebook weights when the number of samples is too high
236
+ # This prevents the codebook from being dominated by its history
237
+ # if a few vectors were visited a lot and also prevents overflows
238
+ self.flush_weights()
239
+
240
+ # commit_loss = (self.continuous_latent - quantized.detach()).square().clamp_min(self.distance_threshold).sum(-1).mean()
241
+
242
+ self.quantization_loss = (
243
+ (self.continuous_latent - quantized).square().sum(-1).mean()
244
+ )
245
+
246
+ quantized = (
247
+ quantized.detach()
248
+ + self.continuous_latent
249
+ - self.continuous_latent.detach()
250
+ )
251
+
252
+ self.latent_diversity = (
253
+ (self.continuous_latent[None, ...] - self.continuous_latent[:, None, ...])
254
+ .square()
255
+ .sum(-1)
256
+ .mean()
257
+ )
258
+
259
+ return quantized, torch.ones_like(quantized[..., 0]) / num_vq
260
+
261
+ def kl_loss(
262
+ self,
263
+ other: "ClassifiedLatentDistribution",
264
+ threshold: float = 0,
265
+ mask_z: Optional[torch.Tensor] = None,
266
+ ) -> torch.Tensor:
267
+ """Compute the cross entropy between two latent distributions."""
268
+ assert type(other) == ClassifiedLatentDistribution
269
+ min_logits = -10
270
+ max_logits = 10
271
+ pred_log = other.logits.clamp(min_logits, max_logits).log_softmax(-1)
272
+ self_pred = self.soft_one_hot
273
+ self.accuracy = (self_pred.argmax(-1) == other.logits.argmax(-1)).float().mean()
274
+ return -2 * (pred_log * self_pred).sum(-1).mean()
275
+
276
+ def sampling_loss(self) -> torch.Tensor:
277
+ if self.quantization_loss is None:
278
+ self.sample()
279
+ return 0.5 * (
280
+ self.quantization_loss + self.uselessness + 0.001 * self.latent_diversity
281
+ )
282
+
283
+ def average(
284
+ self, other: "QuantizedLatentDistribution", weight_other: torch.Tensor
285
+ ) -> "QuantizedLatentDistribution":
286
+ raise NotImplementedError(
287
+ "Average is not implemented for QuantizedLatentDistribution"
288
+ )
289
+
290
+ def log_dict(self, type: str) -> dict:
291
+ log_dict = {
292
+ f"latent/{type}/quantization_loss": self.quantization_loss,
293
+ f"latent/{type}/uselessness": self.uselessness,
294
+ f"latent/{type}/latent_diversity": self.latent_diversity,
295
+ f"latent/{type}/codebook_abs_mean": self.codebook.abs().mean(),
296
+ f"latent/{type}/codebook_std": self.codebook.std(),
297
+ f"latent/{type}/latent_abs_mean": self.continuous_latent.abs().mean(),
298
+ f"latent/{type}/latent_std": self.continuous_latent.std(),
299
+ }
300
+ if self.accuracy is not None:
301
+ log_dict[f"latent/{type}/accuracy"] = self.accuracy
302
+ return log_dict
303
+
304
+
305
+ class ClassifiedLatentDistribution(AbstractLatentDistribution):
306
+ """Classified latent distribution.
307
+ It is defined with a codebook of quantized latents and a probability distribution over the codebook elements.
308
+
309
+ Args:
310
+ logits : Logits of shape (some_shape, num_embeddings)
311
+ codebook : Codebook of shape (num_embeddings, latent_dim)
312
+ """
313
+
314
+ def __init__(self, logits: torch.Tensor, codebook: torch.Tensor):
315
+ super().__init__()
316
+ self.register_buffer("logits", logits, persistent=False)
317
+ self.register_buffer("codebook", codebook, persistent=False)
318
+
319
+ def sample(
320
+ self, n_samples: int = 0, replacement: bool = True, *args, **kwargs
321
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
322
+ batch_size, num_agents, num_vq = self.logits.shape
323
+ squeeze_out = False
324
+ if n_samples == 0:
325
+ squeeze_out = True
326
+ n_samples = 1
327
+ elif n_samples > self.codebook.shape[0]:
328
+ warnings.warn(
329
+ f"Requested {n_samples} samples but only {self.codebook.shape[0]} are available in the descrete latent space. Switching to replacement=True to support it."
330
+ )
331
+ replacement = True
332
+
333
+ if self.training:
334
+ # TODO: should we make the temperature a parameter?
335
+ all_weights, indices = relaxed_one_hot_categorical_without_replacement(
336
+ logits=self.logits, temperature=1, num_samples=n_samples
337
+ )
338
+ selected_latents = self.codebook[indices, :]
339
+ # Cumulative mask of indices that have been sampled in order of probability
340
+ mask_selection = torch.nn.functional.one_hot(indices, num_vq).cumsum(-2)
341
+ mask_selection[..., 1:, :] = mask_selection[..., :-1, :]
342
+ mask_selection[..., 0, :] = 0.0
343
+ # Remove the probability of previous samples to account for sampling without replacement
344
+ masked_weights = all_weights.unsqueeze(-2) * (1 - mask_selection.float())
345
+ # Renormalize the probabilities to sum to 1
346
+ masked_weights = masked_weights / masked_weights.sum(-1, keepdim=True)
347
+
348
+ latent_samples = (
349
+ masked_weights.unsqueeze(-1)
350
+ * self.codebook[None, None, None, ...].detach()
351
+ ).sum(-2)
352
+ latent_samples = (
353
+ selected_latents.detach() + latent_samples - latent_samples.detach()
354
+ )
355
+ probs = torch.gather(self.logits.softmax(-1), -1, indices)
356
+ else:
357
+ probs = self.logits.softmax(-1)
358
+ samples = torch.multinomial(
359
+ probs.view(batch_size * num_agents, num_vq),
360
+ n_samples,
361
+ replacement=replacement,
362
+ )
363
+ latent_samples = self.codebook[samples]
364
+ probs = torch.gather(
365
+ probs, -1, samples.view(batch_size, num_agents, n_samples)
366
+ )
367
+
368
+ if squeeze_out:
369
+ latent_samples = latent_samples.view(
370
+ batch_size, num_agents, self.codebook.shape[-1]
371
+ )
372
+ else:
373
+ latent_samples = latent_samples.view(
374
+ batch_size, num_agents, n_samples, self.codebook.shape[-1]
375
+ )
376
+ return latent_samples, probs
377
+
378
+ def kl_loss(
379
+ self,
380
+ other: "ClassifiedLatentDistribution",
381
+ threshold: float = 0,
382
+ mask_z: Optional[torch.Tensor] = None,
383
+ ) -> torch.Tensor:
384
+ """Compute the cross entropy between two latent distributions. Self being the reference distribution and other the distribution to compare."""
385
+ assert type(other) == ClassifiedLatentDistribution
386
+ min_logits = -10
387
+ max_logits = 10
388
+ pred_log = other.logits.clamp(min_logits, max_logits).log_softmax(-1)
389
+ self_pred = (
390
+ (0.5 * (self.logits.detach() + self.logits))
391
+ .clamp(min_logits, max_logits)
392
+ .softmax(-1)
393
+ )
394
+ return -2 * (pred_log * self_pred).sum(-1).mean()
395
+
396
+ def sampling_loss(self) -> torch.Tensor:
397
+ return torch.zeros(1, device=self.logits.device)
398
+
399
+ def average(
400
+ self, other: "ClassifiedLatentDistribution", weight_other: torch.Tensor
401
+ ) -> "ClassifiedLatentDistribution":
402
+ assert type(other) == ClassifiedLatentDistribution
403
+ assert (self.codebook == other.codebook).all()
404
+ return ClassifiedLatentDistribution(
405
+ (
406
+ self.logits.exp() * (1 - weight_other)
407
+ + other.logits.exp() * weight_other
408
+ ).log(),
409
+ self.codebook,
410
+ )
411
+
412
+ def log_dict(self, type: str) -> dict:
413
+ max_probs, _ = self.logits.softmax(-1).max(-1)
414
+ return {
415
+ f"latent/{type}/codebook_abs_mean": self.codebook.abs().mean(),
416
+ f"latent/{type}/codebook_std": self.codebook.std(),
417
+ f"latent/{type}/class_max_mean": max_probs.mean(),
418
+ f"latent/{type}/class_max_std": max_probs.std(),
419
+ }
420
+
421
+
422
+ class QuantizedDistributionCreator(nn.Module):
423
+ """Creates a distribution from a latent vector."""
424
+
425
+ def __init__(
426
+ self,
427
+ latent_dim: int,
428
+ num_embeddings: int,
429
+ ):
430
+ super().__init__()
431
+ self.latent_dim = latent_dim
432
+ self.num_embeddings = num_embeddings
433
+ self.codebook = nn.Parameter(torch.randn(num_embeddings, latent_dim))
434
+ self.register_buffer(
435
+ "codebook_weights",
436
+ torch.ones(num_embeddings, requires_grad=False),
437
+ persistent=False,
438
+ )
439
+
440
+ def _flush_codebook_weights(self):
441
+ self.codebook_weights = torch.ones_like(self.codebook_weights)
442
+
443
+ def _get_codebook_weights(self):
444
+ return self.codebook_weights
445
+
446
+ def _index_add_one_codebook_weight(self, indices: torch.Tensor):
447
+ self.codebook_weights = self.codebook_weights.index_add(
448
+ 0,
449
+ indices.flatten(),
450
+ torch.ones_like(self.codebook_weights[indices]),
451
+ )
452
+
453
+ def forward(self, latent: torch.Tensor) -> AbstractLatentDistribution:
454
+ if latent.shape[-1] == self.latent_dim:
455
+ return QuantizedLatentDistribution(
456
+ latent,
457
+ self.codebook,
458
+ self._flush_codebook_weights,
459
+ self._get_codebook_weights,
460
+ self._index_add_one_codebook_weight,
461
+ )
462
+ elif latent.shape[-1] == self.num_embeddings:
463
+ return ClassifiedLatentDistribution(
464
+ latent,
465
+ self.codebook,
466
+ )
467
+ else:
468
+ raise ValueError(f"Latent vector has wrong dimension: {latent.shape[-1]}")
risk_biased/models/map_encoder.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from risk_biased.models.nn_blocks import (
4
+ SequenceEncoderLSTM,
5
+ SequenceEncoderMLP,
6
+ SequenceEncoderMaskedLSTM,
7
+ )
8
+
9
+ from risk_biased.models.cvae_params import CVAEParams
10
+ from risk_biased.models.mlp import MLP
11
+
12
+
13
+ class MapEncoderNN(nn.Module):
14
+ """MLP encoder neural network that encodes map objects.
15
+
16
+ Args:
17
+ params: dataclass defining the necessary parameters
18
+ """
19
+
20
+ def __init__(self, params: CVAEParams) -> None:
21
+ super().__init__()
22
+ self._encoder = SequenceEncoderMLP(
23
+ params.map_state_dim,
24
+ params.hidden_dim,
25
+ params.num_hidden_layers,
26
+ params.max_size_lane,
27
+ params.is_mlp_residual,
28
+ )
29
+
30
+ def forward(self, map, mask_map):
31
+ """Forward function encoding map object sequences of features into object features.
32
+
33
+ Args:
34
+ map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
35
+ mask_map: (batch_size, num_objects, object_sequence_length) tensor of bool mask
36
+ """
37
+ encoded_map = self._encoder(map, mask_map)
38
+ return encoded_map
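As a rough usage sketch: the stand-in object below only carries the fields that MapEncoderNN reads from CVAEParams (the real dataclass lives in risk_biased/models/cvae_params.py), and all values are arbitrary.

import torch
from types import SimpleNamespace

# Hypothetical stand-in for CVAEParams, for illustration only.
params = SimpleNamespace(
    map_state_dim=2,        # features per map point
    hidden_dim=64,
    num_hidden_layers=3,
    max_size_lane=20,       # points per map object
    is_mlp_residual=True,
)
map_encoder = MapEncoderNN(params)

batch_size, num_objects = 4, 7
map_features = torch.randn(batch_size, num_objects, params.max_size_lane, params.map_state_dim)
mask_map = torch.ones(batch_size, num_objects, params.max_size_lane, dtype=torch.bool)
encoded_map = map_encoder(map_features, mask_map)  # (4, 7, 64)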
risk_biased/models/mlp.py ADDED
@@ -0,0 +1,60 @@
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class MLP(nn.Module):
8
+ """Basic MLP implementation with FC layers and ReLU activation.
9
+
10
+ Args:
11
+ input_dim : dimension of the input variable
12
+ output_dim : dimension of the output variable
13
+ h_dim : dimension of a hidden layer of MLP
14
+ num_h_layers : number of hidden layers in MLP
15
+ add_residual : set to True to add input to output (res-net) set to False to have pure MLP
16
+
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ input_dim: int,
22
+ output_dim: int,
23
+ h_dim: int,
24
+ num_h_layers: int,
25
+ add_residual: bool,
26
+ ) -> None:
27
+ super().__init__()
28
+ self._input_dim = input_dim
29
+ self._output_dim = output_dim
30
+ self._h_dim = h_dim
31
+ self._num_h_layers = num_h_layers
32
+
33
+ layers = OrderedDict()
34
+ if num_h_layers > 0:
35
+ layers["fc_0"] = nn.Linear(input_dim, h_dim)
36
+ layers["relu_0"] = nn.ReLU()
37
+ else:
38
+ h_dim = input_dim
39
+ for ii in range(1, num_h_layers):
40
+ layers["fc_{}".format(ii)] = nn.Linear(h_dim, h_dim)
41
+ layers["relu_{}".format(ii)] = nn.ReLU()
42
+ layers["fc_{}".format(num_h_layers)] = nn.Linear(h_dim, self._output_dim)
43
+
44
+ self.mlp = nn.Sequential(layers)
45
+ if add_residual:
46
+ self.residual_layer = nn.Linear(input_dim, output_dim)
47
+ else:
48
+ self.residual_layer = lambda x: 0
49
+ self._layer_norm = nn.LayerNorm(output_dim)
50
+
51
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
52
+ """Forward function for MLP
53
+
54
+ Args:
55
+ input (torch.Tensor): (batch_size, input_dim) tensor
56
+
57
+ Returns:
58
+ torch.Tensor: (batch_size, output_dim) tensor
59
+ """
60
+ return self._layer_norm(self.mlp(input) + self.residual_layer(input))
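A small shape sketch of this MLP block (dimensions are arbitrary):

import torch

mlp = MLP(input_dim=16, output_dim=32, h_dim=64, num_h_layers=2, add_residual=True)
x = torch.randn(8, 16)   # (batch_size, input_dim)
y = mlp(x)               # (batch_size, output_dim) == (8, 32), layer-normalized
# With num_h_layers=0 the block reduces to a single linear layer followed by layer norm.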
risk_biased/models/multi_head_attention.py ADDED
@@ -0,0 +1,81 @@
1
+ # Implementation from https://einops.rocks/pytorch-examples.html slightly changed
2
+
3
+
4
+ import math
5
+
6
+ from typing import Tuple
7
+ import torch
8
+ from torch import nn
9
+ from einops import rearrange, repeat
10
+
11
+
12
+ class MultiHeadAttention(nn.Module):
13
+ """
14
+ This is a slightly modified version of the original implementation from https://einops.rocks/pytorch-examples.html of multihead attention.
15
+ It keeps the original dimension division per head and masks the attention matrix before and after the softmax to support full row masking.
16
+
17
+ Args:
18
+ d_model: the input feature dimension of the model
19
+ n_head: the number of heads in the multihead attention
20
+ d_k: the dimension of the key and query in the multihead attention
21
+ d_v: the dimension of the value in the multihead attention
22
+ """
23
+
24
+ def __init__(self, d_model: int, n_head: int, d_k: int, d_v: int):
25
+ super().__init__()
26
+ self.n_head = n_head
27
+
28
+ self.w_qs = nn.Linear(d_model, int(d_k / n_head) * n_head)
29
+ self.w_ks = nn.Linear(d_model, int(d_k / n_head) * n_head)
30
+ self.w_vs = nn.Linear(d_model, int(d_v / n_head) * n_head)
31
+ self.w_rs = nn.Linear(d_model, int(d_v / n_head) * n_head)
32
+
33
+ nn.init.normal_(self.w_qs.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_k)))
34
+ nn.init.normal_(self.w_ks.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_k)))
35
+ nn.init.normal_(self.w_vs.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_v)))
36
+ nn.init.normal_(self.w_rs.weight, mean=0, std=math.sqrt(2.0 / (d_model + d_v)))
37
+
38
+ self.fc = nn.Linear(int(d_v / n_head) * n_head, d_model)
39
+ nn.init.xavier_normal_(self.fc.weight)
40
+ self.layer_norm = nn.LayerNorm(d_model)
41
+
42
+ def forward(
43
+ self,
44
+ q: torch.Tensor,
45
+ k: torch.Tensor,
46
+ v: torch.Tensor,
47
+ mask: torch.Tensor = None,
48
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
49
+ """
50
+ Compute the masked multi-head attention given the query, key and value tensors.
51
+
52
+ Args:
53
+ q: the query tensor of shape [batch_size, number_of_agents, d_model]
54
+ k: the key tensor of shape [batch_size, number_of_objects, d_model]
55
+ v: the value tensor of shape [batch_size, number_of_objects, d_model]
56
+ mask: the mask tensor of shape [batch_size, number_of_agents, number_of_objects]
57
+
58
+ Returns:
59
+ [
60
+ The attention output tensor of shape [batch_size, number_of_agents, d_model],
61
+ The attention matrix of shape [batch_size, number_of_agents, number_of_objects]
62
+ ]
63
+ """
64
+ residual = q.clone()
65
+ r = self.w_rs(q)
66
+ q = rearrange(self.w_qs(q), "b a (head k) -> head b a k", head=self.n_head)
67
+ k = rearrange(self.w_ks(k), "b o (head k) -> head b o k", head=self.n_head)
68
+ v = rearrange(self.w_vs(v), "b o (head v) -> head b o v", head=self.n_head)
69
+ attn = torch.einsum("hbak,hbok->hbao", [q, k]) / math.sqrt(q.shape[-1])
70
+ if mask is not None:
71
+ # b: batch, a: agent, o: object, h: head
72
+ mask = repeat(mask, "b a o -> h b a o", h=self.n_head)
73
+ attn = attn.masked_fill(mask == 0, -math.inf)
74
+ attn = torch.softmax(attn, dim=3)
75
+ # Here we need to mask again because some lines might be all -inf in the softmax which gives Nan...
76
+ attn = attn.masked_fill(mask == 0, 0)
77
+ output = torch.einsum("hbao,hbov->hbav", [attn, v])
78
+ output = rearrange(output, "head b a v -> b a (head v)")
79
+ output = self.fc(output * r)
80
+ output = self.layer_norm(output + residual)
81
+ return output, attn
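A shape-level sketch of the masked attention above (sizes are illustrative; d_k and d_v are total dimensions that get split across the heads):

import torch

attention = MultiHeadAttention(d_model=64, n_head=4, d_k=64, d_v=64)

batch_size, num_agents, num_objects = 2, 3, 5
q = torch.randn(batch_size, num_agents, 64)
k = torch.randn(batch_size, num_objects, 64)
v = torch.randn(batch_size, num_objects, 64)
# 1.0 marks a valid (agent, object) pair, 0.0 a padded one; fully masked rows are supported.
mask = torch.ones(batch_size, num_agents, num_objects)
mask[:, -1, :] = 0.0

output, attn = attention(q, k, v, mask=mask)
# output: (2, 3, 64); attn: (4, 2, 3, 5), one attention map per head.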
risk_biased/models/nn_blocks.py ADDED
@@ -0,0 +1,626 @@
1
+ from einops.layers.torch import Rearrange
2
+ from einops import rearrange, repeat
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from risk_biased.models.multi_head_attention import MultiHeadAttention
7
+ from risk_biased.models.context_gating import ContextGating
8
+ from risk_biased.models.mlp import MLP
9
+
10
+
11
+ class SequenceEncoderMaskedLSTM(nn.Module):
12
+ """MLP followed with a masked LSTM implementation with one layer.
13
+
14
+ Args:
15
+ input_dim : dimension of the input variable
16
+ h_dim : dimension of a hidden layer of MLP
17
+ """
18
+
19
+ def __init__(self, input_dim: int, h_dim: int) -> None:
20
+ super().__init__()
21
+ self._group_objects = Rearrange("b o ... -> (b o) ...")
22
+ self._embed = nn.Linear(in_features=input_dim, out_features=h_dim)
23
+ self._lstm = nn.LSTMCell(
24
+ input_size=h_dim, hidden_size=h_dim
25
+ ) # expects(batch,seq,features)
26
+ self.h0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
27
+ self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
28
+
29
+ def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
30
+ """Forward function for MapEncoder
31
+
32
+ Args:
33
+ input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
34
+ mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good False if data is missing)
35
+
36
+ Returns:
37
+ torch.Tensor: (batch_size, num_objects, output_dim) tensor
38
+ """
39
+
40
+ batch_size, num_objects, seq_len, _ = input.shape
41
+ split_objects = Rearrange("(b o) f -> b o f", b=batch_size, o=num_objects)
42
+
43
+ input = self._group_objects(input)
44
+ mask_input = self._group_objects(mask_input)
45
+ embedded_input = self._embed(input)
46
+
47
+ # One to many encoding of the input sequence with masking for missing points
48
+ mask_input = mask_input.float()
49
+ h = mask_input[:, 0, None] * embedded_input[:, 0, :] + (
50
+ 1 - mask_input[:, 0, None]
51
+ ) * repeat(self.h0, "b f -> (size b) f", size=batch_size * num_objects)
52
+ c = repeat(self.c0, "b f -> (size b) f", size=batch_size * num_objects)
53
+ for i in range(seq_len):
54
+ new_input = (
55
+ mask_input[:, i, None] * embedded_input[:, i, :]
56
+ + (1 - mask_input[:, i, None]) * h
57
+ )
58
+ h, c = self._lstm(new_input, (h, c))
59
+ return split_objects(h)
60
+
61
+
62
+ class SequenceEncoderLSTM(nn.Module):
63
+ """MLP followed with an LSTM with one layer.
64
+
65
+ Args:
66
+ input_dim : dimension of the input variable
67
+ h_dim : dimension of a hidden layer of MLP
68
+ """
69
+
70
+ def __init__(self, input_dim: int, h_dim: int) -> None:
71
+ super().__init__()
72
+ self._group_objects = Rearrange("b o ... -> (b o) ...")
73
+ self._embed = nn.Linear(in_features=input_dim, out_features=h_dim)
74
+ self._lstm = nn.LSTM(
75
+ input_size=h_dim,
76
+ hidden_size=h_dim,
77
+ batch_first=True,
78
+ ) # expects(batch,seq,features)
79
+ self.h0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
80
+ self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
81
+
82
+ def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
83
+ """Forward function for MapEncoder
84
+
85
+ Args:
86
+ input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
87
+ mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good False if data is missing)
88
+
89
+ Returns:
90
+ torch.Tensor: (batch_size, num_objects, output_dim) tensor
91
+ """
92
+
93
+ batch_size, num_objects, seq_len, _ = input.shape
94
+ split_objects = Rearrange("(b o) f -> b o f", b=batch_size, o=num_objects)
95
+
96
+ input = self._group_objects(input)
97
+ mask_input = self._group_objects(mask_input)
98
+ embedded_input = self._embed(input)
99
+
100
+ # One to many encoding of the input sequence with masking for missing points
101
+ mask_input = mask_input.float()
102
+ h = (
103
+ mask_input[:, 0, None] * embedded_input[:, 0, :]
104
+ + (1 - mask_input[:, 0, None])
105
+ * repeat(
106
+ self.h0, "one f -> one size f", size=batch_size * num_objects
107
+ ).contiguous()
108
+ )
109
+ c = repeat(
110
+ self.c0, "one f -> one size f", size=batch_size * num_objects
111
+ ).contiguous()
112
+ _, (h, _) = self._lstm(embedded_input, (h, c))
113
+ # for i in range(seq_len):
114
+ # new_input = (
115
+ # mask_input[:, i, None] * embedded_input[:, i, :]
116
+ # + (1 - mask_input[:, i, None]) * h
117
+ # )
118
+ # h, c = self._lstm(new_input, (h, c))
119
+ return split_objects(h.squeeze(0))
120
+
121
+
122
+ class SequenceEncoderMLP(nn.Module):
123
+ """MLP implementation.
124
+
125
+ Args:
126
+ input_dim : dimension of the input variable
127
+ h_dim : dimension of a hidden layer of MLP
128
+ num_layers: number of layers to use in the MLP
129
+ sequence_length: dimension of the input sequence
130
+ is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
131
+ """
132
+
133
+ def __init__(
134
+ self,
135
+ input_dim: int,
136
+ h_dim: int,
137
+ num_layers: int,
138
+ sequence_length: int,
139
+ is_mlp_residual: bool,
140
+ ) -> None:
141
+ super().__init__()
142
+ self._mlp = MLP(
143
+ input_dim * sequence_length, h_dim, h_dim, num_layers, is_mlp_residual
144
+ )
145
+
146
+ def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
147
+ """Forward function for MapEncoder
148
+
149
+ Args:
150
+ input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
151
+ mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good False if data is missing)
152
+
153
+ Returns:
154
+ torch.Tensor: (batch_size, num_objects, output_dim) tensor
155
+ """
156
+
157
+ batch_size, num_objects, _, _ = input.shape
158
+ input = input * mask_input.unsqueeze(-1)
159
+ h = rearrange(input, "b o s f -> (b o) (s f)")
160
+ mask_input = rearrange(mask_input, "b o s -> (b o) s")
161
+ if h.shape[-1] == 0:
162
+ h = h.view(batch_size, 0, h.shape[0])
163
+ else:
164
+ h = self._mlp(h)
165
+ h = rearrange(h, "(b o) f -> b o f", b=batch_size, o=num_objects)
166
+
167
+ return h
168
+
169
+
170
+ class SequenceDecoderLSTM(nn.Module):
171
+ """A one to many LSTM implementation with one layer.
172
+
173
+ Args:
174
+ h_dim : dimension of a hidden layer
175
+ """
176
+
177
+ def __init__(self, h_dim: int) -> None:
178
+ super().__init__()
179
+ self._group_objects = Rearrange("b o f -> (b o) f")
180
+ self._lstm = nn.LSTM(input_size=h_dim, hidden_size=h_dim)
181
+ self._out_layer = nn.Linear(in_features=h_dim, out_features=h_dim)
182
+ self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
183
+
184
+ def forward(self, input: torch.Tensor, sequence_length: int) -> torch.Tensor:
185
+ """Forward function for MapEncoder
186
+
187
+ Args:
188
+ input (torch.Tensor): (batch_size, num_objects, input_dim) tensor
189
+ sequence_length: output sequence length to create
190
+ Returns:
191
+ torch.Tensor: (batch_size, num_objects, output_dim) tensor
192
+ """
193
+
194
+ batch_size, num_objects, _ = input.shape
195
+
196
+ h = repeat(input, "b o f -> one (b o) f", one=1).contiguous()
197
+ c = repeat(
198
+ self.c0, "one f -> one size f", size=batch_size * num_objects
199
+ ).contiguous()
200
+ seq_h = repeat(h, "one b f -> (one t) b f", t=sequence_length).contiguous()
201
+ h, (_, _) = self._lstm(seq_h, (h, c))
202
+ h = rearrange(h, "t (b o) f -> b o t f", b=batch_size, o=num_objects)
203
+ return self._out_layer(h)
204
+
205
+
206
+ class SequenceDecoderMLP(nn.Module):
207
+ """A one to many MLP implementation.
208
+
209
+ Args:
210
+ h_dim : dimension of a hidden layer
211
+ num_layers: number of layers to use in the MLP
212
+ sequence_length: output sequence length to return
213
+ is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
214
+ """
215
+
216
+ def __init__(
217
+ self, h_dim: int, num_layers: int, sequence_length: int, is_mlp_residual: bool
218
+ ) -> None:
219
+ super().__init__()
220
+ self._mlp = MLP(
221
+ h_dim, h_dim * sequence_length, h_dim, num_layers, is_mlp_residual
222
+ )
223
+
224
+ def forward(self, input: torch.Tensor, sequence_length: int) -> torch.Tensor:
225
+ """Forward function for MapEncoder
226
+
227
+ Args:
228
+ input (torch.Tensor): (batch_size, num_objects, input_dim) tensor
229
+ sequence_length: output sequence length to create
230
+ Returns:
231
+ torch.Tensor: (batch_size, num_objects, output_dim) tensor
232
+ """
233
+
234
+ batch_size, num_objects, _ = input.shape
235
+
236
+ h = rearrange(input, "b o f -> (b o) f")
237
+ h = self._mlp(h)
238
+ h = rearrange(
239
+ h, "(b o) (s f) -> b o s f", b=batch_size, o=num_objects, s=sequence_length
240
+ )
241
+ return h
242
+
243
+
244
+ class AttentionBlock(nn.Module):
245
+ """Block performing agent-map cross attention->ReLU(linear)->+residual->layer_norm->agent-agent attention->ReLU(linear)->+residual->layer_norm
246
+ Args:
247
+ hidden_dim: feature dimension
248
+ num_attention_heads: number of attention heads to use
249
+ """
250
+
251
+ def __init__(self, hidden_dim: int, num_attention_heads: int):
252
+ super().__init__()
253
+ self._num_attention_heads = num_attention_heads
254
+ self._agent_map_attention = MultiHeadAttention(
255
+ hidden_dim, num_attention_heads, hidden_dim, hidden_dim
256
+ )
257
+ self._lin1 = nn.Linear(hidden_dim, hidden_dim)
258
+ self._layer_norm1 = nn.LayerNorm(hidden_dim)
259
+ self._agent_agent_attention = MultiHeadAttention(
260
+ hidden_dim, num_attention_heads, hidden_dim, hidden_dim
261
+ )
262
+ self._lin2 = nn.Linear(hidden_dim, hidden_dim)
263
+ self._layer_norm2 = nn.LayerNorm(hidden_dim)
264
+ self._activation = nn.ReLU()
265
+
266
+ def forward(
267
+ self,
268
+ encoded_agents: torch.Tensor,
269
+ mask_agents: torch.Tensor,
270
+ encoded_absolute_agents: torch.Tensor,
271
+ encoded_map: torch.Tensor,
272
+ mask_map: torch.Tensor,
273
+ ) -> torch.Tensor:
274
+ """Forward function of the block, returning only the output (no attention matrix)
275
+ Args:
276
+ encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
277
+ mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
278
+ encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
279
+ encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
280
+ mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
281
+ """
282
+
283
+ # Check if map_info is available. If not, don't compute cross-attention with it
284
+ if mask_map.any():
285
+ mask_agent_map = torch.einsum("ba,bo->bao", mask_agents, mask_map)
286
+ h, _ = self._agent_map_attention(
287
+ encoded_agents + encoded_absolute_agents,
288
+ encoded_map,
289
+ encoded_map,
290
+ mask=mask_agent_map,
291
+ )
292
+ h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
293
+ h = torch.sigmoid(self._lin1(h))
294
+ h = self._layer_norm1(encoded_agents + h)
295
+ else:
296
+ h = self._layer_norm1(encoded_agents)
297
+
298
+ h_res = h.clone()
299
+ agent_agent_mask = torch.einsum("ba,be->bae", mask_agents, mask_agents)
300
+ h = h + encoded_absolute_agents
301
+ h, _ = self._agent_agent_attention(h, h, h, mask=agent_agent_mask)
302
+ h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
303
+ h = self._activation(self._lin2(h))
304
+ h = self._layer_norm2(h_res + h)
305
+ return h
306
+
307
+
308
+ class CG_block(nn.Module):
309
+ """Block performing context gating agent-map
310
+ Args:
311
+ hidden_dim: feature dimension
312
+ dim_expansion: multiplicative factor on the hidden dimension for the global context representation
313
+ num_layers: number of layers to use in the MLP for context encoding
314
+ is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
315
+ """
316
+
317
+ def __init__(
318
+ self,
319
+ hidden_dim: int,
320
+ dim_expansion: int,
321
+ num_layers: int,
322
+ is_mlp_residual: bool,
323
+ ):
324
+ super().__init__()
325
+ self._agent_map = ContextGating(
326
+ hidden_dim,
327
+ hidden_dim * dim_expansion,
328
+ num_layers=num_layers,
329
+ is_mlp_residual=is_mlp_residual,
330
+ )
331
+ self._lin1 = nn.Linear(hidden_dim, hidden_dim)
332
+ self._layer_norm1 = nn.LayerNorm(hidden_dim)
333
+ self._agent_agent = ContextGating(
334
+ hidden_dim, hidden_dim * dim_expansion, num_layers, is_mlp_residual
335
+ )
336
+ self._lin2 = nn.Linear(hidden_dim, hidden_dim)
337
+ self._activation = nn.ReLU()
338
+
339
+ def forward(
340
+ self,
341
+ encoded_agents: torch.Tensor,
342
+ mask_agents: torch.Tensor,
343
+ encoded_absolute_agents: torch.Tensor,
344
+ encoded_map: torch.Tensor,
345
+ mask_map: torch.Tensor,
346
+ global_context: torch.Tensor,
347
+ ) -> torch.Tensor:
348
+ """Forward function of the block, returning the output and global context
349
+ Args:
350
+ encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
351
+ mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
352
+ encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
353
+ encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
354
+ mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
355
+ global_context: (batch_size, dim_context) tensor representing the global context
356
+ """
357
+
358
+ # Check if map_info is available. If not, don't compute cross-interaction with it
359
+ if mask_map.any():
360
+ s, global_context = self._agent_map(
361
+ encoded_agents + encoded_absolute_agents, encoded_map, global_context
362
+ )
363
+ s = s * mask_agents.unsqueeze(-1)
364
+ s = self._activation(self._lin1(s))
365
+ s = self._layer_norm1(encoded_agents + s)
366
+ else:
367
+ s = self._layer_norm1(encoded_agents)
368
+
369
+ s = s + encoded_absolute_agents
370
+ s, global_context = self._agent_agent(s, s, global_context)
371
+ s = s * mask_agents.unsqueeze(-1)
372
+ s = self._lin2(s)
373
+ return s, global_context
374
+
375
+
376
+ class HybridBlock(nn.Module):
377
+ """Block performing agent-map cross context_gating->ReLU(linear)->+residual->layer_norm->agent-agent attention->ReLU(linear)->+residual->layer_norm
378
+ Args:
379
+ hidden_dim: feature dimension
380
+ num_attention_heads: number of attention heads to use
381
+ dim_expansion: multiplicative factor on the hidden dimension for the global context representation
382
+ num_layers: number of layers to use in the MLP for context encoding
383
+ is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
384
+ """
385
+
386
+ def __init__(
387
+ self,
388
+ hidden_dim: int,
389
+ num_attention_heads: int,
390
+ dim_expansion: int,
391
+ num_layers: int,
392
+ is_mlp_residual: bool,
393
+ ):
394
+ super().__init__()
395
+ self._num_attention_heads = num_attention_heads
396
+ self._agent_map_cg = ContextGating(
397
+ hidden_dim,
398
+ hidden_dim * dim_expansion,
399
+ num_layers=num_layers,
400
+ is_mlp_residual=is_mlp_residual,
401
+ )
402
+ self._lin1 = nn.Linear(hidden_dim, hidden_dim)
403
+ self._layer_norm1 = nn.LayerNorm(hidden_dim)
404
+ self._agent_agent_attention = MultiHeadAttention(
405
+ hidden_dim, num_attention_heads, hidden_dim, hidden_dim
406
+ )
407
+ self._lin2 = nn.Linear(hidden_dim, hidden_dim)
408
+ self._layer_norm2 = nn.LayerNorm(hidden_dim)
409
+ self._activation = nn.ReLU()
410
+
411
+ def forward(
412
+ self,
413
+ encoded_agents: torch.Tensor,
414
+ mask_agents: torch.Tensor,
415
+ encoded_absolute_agents: torch.Tensor,
416
+ encoded_map: torch.Tensor,
417
+ mask_map: torch.Tensor,
418
+ global_context: torch.Tensor,
419
+ ) -> torch.Tensor:
420
+ """Forward function of the block, returning the output and the context (no attention matrix)
421
+ Args:
422
+ encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
423
+ mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
424
+ encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
425
+ encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
426
+ mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
427
+ global_context: (batch_size, dim_context) tensor representing the global context
428
+ """
429
+
430
+ # Check if map_info is available. If not, don't compute cross-context gating with it
431
+ if mask_map.any():
432
+ # mask_agent_map = torch.logical_not(
433
+ # torch.einsum("ba,bo->bao", mask_agents, mask_map)
434
+ # )
435
+ h, global_context = self._agent_map_cg(
436
+ encoded_agents + encoded_absolute_agents, encoded_map, global_context
437
+ )
438
+ h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
439
+ h = self._activation(self._lin1(h))
440
+ h = self._layer_norm1(encoded_agents + h)
441
+ else:
442
+ h = self._layer_norm1(encoded_agents)
443
+
444
+ h_res = h.clone()
445
+ agent_agent_mask = torch.einsum("ba,be->bae", mask_agents, mask_agents)
446
+ h = h + encoded_absolute_agents
447
+ h, _ = self._agent_agent_attention(h, h, h, mask=agent_agent_mask)
448
+ h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
449
+ h = self._activation(self._lin2(h))
450
+ h = self._layer_norm2(h_res + h)
451
+ return h, global_context
452
+
453
+
454
+ class MCG(nn.Module):
455
+ """Multiple context encoding blocks
456
+ Args:
457
+ hidden_dim: feature dimension
458
+ dim_expansion: multiplicative factor on the hidden dimension for the global context representation
459
+ num_layers: number of layers to use in the MLP for context encoding
460
+ num_blocks: number of successive context encoding blocks to use in the module
461
+ is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
462
+ """
463
+
464
+ def __init__(
465
+ self,
466
+ hidden_dim: int,
467
+ dim_expansion: int,
468
+ num_layers: int,
469
+ num_blocks: int,
470
+ is_mlp_residual: bool,
471
+ ):
472
+ super().__init__()
473
+ self.initial_global_context = nn.parameter.Parameter(
474
+ torch.ones(1, hidden_dim * dim_expansion)
475
+ )
476
+ list_cg = []
477
+ for i in range(num_blocks):
478
+ list_cg.append(
479
+ CG_block(hidden_dim, dim_expansion, num_layers, is_mlp_residual)
480
+ )
481
+ self.mcg = nn.ModuleList(list_cg)
482
+
483
+ def forward(
484
+ self,
485
+ encoded_agents: torch.Tensor,
486
+ mask_agents: torch.Tensor,
487
+ encoded_absolute_agents: torch.Tensor,
488
+ encoded_map: torch.Tensor,
489
+ mask_map: torch.Tensor,
490
+ ) -> torch.Tensor:
491
+ """Forward function of the block, returning only the output (no context)
492
+ Args:
493
+ encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
494
+ mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
495
+ encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
496
+ encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
497
+ mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
498
+ """
499
+ s = encoded_agents
500
+ c = self.initial_global_context
501
+ sum_s = s
502
+ sum_c = c
503
+ for i, cg in enumerate(self.mcg):
504
+ s_new, c_new = cg(
505
+ s, mask_agents, encoded_absolute_agents, encoded_map, mask_map, c
506
+ )
507
+ sum_s = sum_s + s_new
508
+ sum_c = sum_c + c_new
509
+ s = (sum_s / (i + 2)).clone()
510
+ c = (sum_c / (i + 2)).clone()
511
+ return s
512
+
513
+
514
+ class MAB(nn.Module):
515
+ """Multiple Attention Blocks
516
+ Args:
517
+ hidden_dim: feature dimension
518
+ num_attention_heads: number of attention heads to use
519
+ num_blocks: number of successive blocks to use in the module.
520
+ """
521
+
522
+ def __init__(
523
+ self,
524
+ hidden_dim: int,
525
+ num_attention_heads: int,
526
+ num_blocks: int,
527
+ ):
528
+ super().__init__()
529
+ list_attention = []
530
+ for i in range(num_blocks):
531
+ list_attention.append(AttentionBlock(hidden_dim, num_attention_heads))
532
+ self.attention_blocks = nn.ModuleList(list_attention)
533
+
534
+ def forward(
535
+ self,
536
+ encoded_agents: torch.Tensor,
537
+ mask_agents: torch.Tensor,
538
+ encoded_absolute_agents: torch.Tensor,
539
+ encoded_map: torch.Tensor,
540
+ mask_map: torch.Tensor,
541
+ ) -> torch.Tensor:
542
+ """Forward function of the block, returning only the output (no attention matrix)
543
+ Args:
544
+ encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
545
+ mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
546
+ encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
547
+ encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
548
+ mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
549
+ """
550
+ h = encoded_agents
551
+ sum_h = h
552
+ for i, attention in enumerate(self.attention_blocks):
553
+ h_new = attention(
554
+ h, mask_agents, encoded_absolute_agents, encoded_map, mask_map
555
+ )
556
+ sum_h = sum_h + h_new
557
+ h = (sum_h / (i + 2)).clone()
558
+ return h
559
+
560
+
561
+ class MHB(nn.Module):
562
+ """Multiple Hybrid Blocks
563
+ Args:
564
+ hidden_dim: feature dimension
565
+ num_attention_heads: number of attention heads to use
566
+ dim_expansion: multiplicative factor on the hidden dimension for the global context representation
567
+ num_layers: number of layers to use in the MLP for context encoding
568
+ num_blocks: number of successive blocks to use in the module.
569
+ is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
570
+ """
571
+
572
+ def __init__(
573
+ self,
574
+ hidden_dim: int,
575
+ num_attention_heads: int,
576
+ dim_expansion: int,
577
+ num_layers: int,
578
+ num_blocks: int,
579
+ is_mlp_residual: bool,
580
+ ):
581
+ super().__init__()
582
+ self.initial_global_context = nn.parameter.Parameter(
583
+ torch.ones(1, hidden_dim * dim_expansion)
584
+ )
585
+ list_hb = []
586
+ for i in range(num_blocks):
587
+ list_hb.append(
588
+ HybridBlock(
589
+ hidden_dim,
590
+ num_attention_heads,
591
+ dim_expansion,
592
+ num_layers,
593
+ is_mlp_residual,
594
+ )
595
+ )
596
+ self.hybrid_blocks = nn.ModuleList(list_hb)
597
+
598
+ def forward(
599
+ self,
600
+ encoded_agents: torch.Tensor,
601
+ mask_agents: torch.Tensor,
602
+ encoded_absolute_agents: torch.Tensor,
603
+ encoded_map: torch.Tensor,
604
+ mask_map: torch.Tensor,
605
+ ) -> torch.Tensor:
606
+ """Forward function of the block, returning only the output (no attention matrix nor context)
607
+ Args:
608
+ encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
609
+ mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
610
+ encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
611
+ encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
612
+ mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
613
+ """
614
+ sum_h = encoded_agents
615
+ sum_c = self.initial_global_context
616
+ h = encoded_agents
617
+ c = self.initial_global_context
618
+ for i, hb in enumerate(self.hybrid_blocks):
619
+ h_new, c_new = hb(
620
+ h, mask_agents, encoded_absolute_agents, encoded_map, mask_map, c
621
+ )
622
+ sum_h = sum_h + h_new
623
+ sum_c = sum_c + c_new
624
+ h = (sum_h / (i + 2)).clone()
625
+ c = (sum_c / (i + 2)).clone()
626
+ return h
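The three interaction modules above (MCG, MAB, MHB) share the same call signature; the sketch below uses MAB with made-up sizes and float masks where 1.0 marks real entries and 0.0 marks padding.

import torch

hidden_dim = 64
interaction = MAB(hidden_dim=hidden_dim, num_attention_heads=4, num_blocks=3)

batch_size, num_agents, num_objects = 2, 5, 7
encoded_agents = torch.randn(batch_size, num_agents, hidden_dim)
encoded_absolute_agents = torch.randn(batch_size, num_agents, hidden_dim)
encoded_map = torch.randn(batch_size, num_objects, hidden_dim)
mask_agents = torch.ones(batch_size, num_agents)   # 1.0 = real agent, 0.0 = padding
mask_map = torch.ones(batch_size, num_objects)     # 1.0 = real map object, 0.0 = padding

out = interaction(
    encoded_agents, mask_agents, encoded_absolute_agents, encoded_map, mask_map
)  # (2, 5, 64)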
risk_biased/mpc_planner/__init__.py ADDED
File without changes
risk_biased/mpc_planner/dynamics.py ADDED
@@ -0,0 +1,49 @@
1
+ import torch
2
+
3
+ from risk_biased.utils.planner_utils import AbstractState, to_state
4
+
5
+
6
+ class PositionVelocityDoubleIntegrator:
7
+ """Deterministic discrete-time double-integrator dynamics, where state is
8
+ [position_x_m, position_y_m, velocity_x_m_s, velocity_y_m_s] and control is
9
+ [acceleration_x_m_s2, acceleration_y_m_s2].
10
+
11
+ Args:
12
+ dt: time differential between two discrete timesteps in seconds
13
+ """
14
+
15
+ def __init__(self, dt: float):
16
+ self.dt = dt
17
+ self.control_dim = 2
18
+
19
+ def simulate(
20
+ self,
21
+ state_init: AbstractState,
22
+ control_input: torch.Tensor,
23
+ ) -> AbstractState:
24
+ """Euler-integrate dynamics from the initial position and the initial velocity given
25
+ an acceleration input
26
+
27
+ Args:
28
+ state_init: initial Markov state of the system
29
+ control_input: (num_agents, num_steps_future, 2) tensor of acceleration input
30
+
31
+ Returns:
32
+ AbstractState of the (num_agents, num_steps_future, 4) simulated future position-velocity state
33
+ sequence
34
+ """
35
+ position_init, velocity_init = state_init.position, state_init.velocity
36
+
37
+ assert (
38
+ control_input.shape[-1] == self.control_dim
39
+ ), "invalid control input dimension"
40
+
41
+ velocity_future = velocity_init + self.dt * torch.cumsum(control_input, dim=-2)
42
+
43
+ position_future = position_init + self.dt * torch.cumsum(
44
+ velocity_future, dim=-2
45
+ )
46
+ state_future = to_state(
47
+ torch.cat((position_future, velocity_future), dim=-1), self.dt
48
+ )
49
+ return state_future
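The integration itself is just two cumulative sums; the raw-tensor sketch below reproduces the same Euler scheme to show the expected shapes (AbstractState / to_state come from risk_biased.utils.planner_utils and are bypassed here for illustration).

import torch

dt = 0.1
num_agents, num_steps_future = 3, 10
position_init = torch.zeros(num_agents, 1, 2)                 # meters
velocity_init = torch.zeros(num_agents, 1, 2)
velocity_init[..., 0] = 1.0                                   # 1 m/s along x
control_input = torch.zeros(num_agents, num_steps_future, 2)  # accelerations in m/s^2
control_input[..., 0] = 0.5                                   # constant 0.5 m/s^2 along x

# Same Euler scheme as PositionVelocityDoubleIntegrator.simulate:
velocity_future = velocity_init + dt * torch.cumsum(control_input, dim=-2)
position_future = position_init + dt * torch.cumsum(velocity_future, dim=-2)
# velocity_future, position_future: (3, 10, 2); at step t, vx = 1 + 0.5 * dt * (t + 1)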
risk_biased/mpc_planner/planner.py ADDED
@@ -0,0 +1,332 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, Tuple
3
+
4
+ import torch
5
+ from mmcv import Config
6
+
7
+ from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
8
+ from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
9
+ from risk_biased.mpc_planner.planner_cost import TrackingCostParams
10
+ from risk_biased.mpc_planner.solver import CrossEntropySolver, CrossEntropySolverParams
11
+ from risk_biased.mpc_planner.planner_cost import TrackingCost
12
+ from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
13
+ from risk_biased.utils.planner_utils import AbstractState, to_state
14
+ from risk_biased.utils.risk import get_risk_estimator
15
+
16
+
17
+ @dataclass
18
+ class MPCPlannerParams:
19
+ """Dataclass for MPC-Planner Parameters
20
+
21
+ Args:
22
+ dt: discrete time interval in seconds that is used for planning
23
+ num_steps: number of time steps for which history of ego's and the other actor's
24
+ trajectories are stored
25
+ num_steps_future: number of time steps into the future for which ego's and the other actor's
26
+ trajectories are considered
27
+ acceleration_std_x_m_s2: Acceleration noise standard deviation (m/s^2) in x-direction that
28
+ is used to initialize the Cross Entropy solver
29
+ acceleration_std_y_m_s2: Acceleration noise standard deviation (m/s^2) in y-direction that
30
+ is used to initialize the Cross Entropy solver
31
+ risk_estimator_params: parameters for the Monte Carlo risk estimator used in the planner for
32
+ ego's control optimization
33
+ solver_params: parameters for the CrossEntropySolver
34
+ tracking_cost_params: parameters for the TrackingCost
35
+ ttc_cost_params: parameters for the TTCCost (i.e., collision cost between ego and the other
36
+ actor)
37
+ """
38
+
39
+ dt: float
40
+ num_steps: int
41
+ num_steps_future: int
42
+ acceleration_std_x_m_s2: float
43
+ acceleration_std_y_m_s2: float
44
+
45
+ risk_estimator_params: dict
46
+ solver_params: CrossEntropySolverParams
47
+ tracking_cost_params: TrackingCostParams
48
+ ttc_cost_params: TTCCostParams
49
+
50
+ @staticmethod
51
+ def from_config(cfg: Config):
52
+ return MPCPlannerParams(
53
+ cfg.dt,
54
+ cfg.num_steps,
55
+ cfg.num_steps_future,
56
+ cfg.acceleration_std_x_m_s2,
57
+ cfg.acceleration_std_y_m_s2,
58
+ cfg.risk_estimator,
59
+ CrossEntropySolverParams.from_config(cfg),
60
+ TrackingCostParams.from_config(cfg),
61
+ TTCCostParams.from_config(cfg),
62
+ )
63
+
64
+
65
+ class MPCPlanner:
66
+ """MPC Planner with a Cross Entropy solver
67
+
68
+ Args:
69
+ params: MPCPlannerParams object
70
+ predictor: LitTrajectoryPredictor object
71
+ normalizer: function that takes in an unnormalized trajectory and that outputs the
72
+ normalized trajectory and the offset in this order
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ params: MPCPlannerParams,
78
+ predictor: LitTrajectoryPredictor,
79
+ normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
80
+ ) -> None:
81
+
82
+ self.params = params
83
+ self.dynamics_model = PositionVelocityDoubleIntegrator(params.dt)
84
+ self.control_input_mean_init = torch.zeros(
85
+ 1, params.num_steps_future, self.dynamics_model.control_dim
86
+ )
87
+ self.control_input_std_init = torch.Tensor(
88
+ [
89
+ params.acceleration_std_x_m_s2,
90
+ params.acceleration_std_y_m_s2,
91
+ ]
92
+ ).expand_as(self.control_input_mean_init)
93
+ self.solver = CrossEntropySolver(
94
+ params=params.solver_params,
95
+ dynamics_model=self.dynamics_model,
96
+ control_input_mean=self.control_input_mean_init,
97
+ control_input_std=self.control_input_std_init,
98
+ tracking_cost_function=TrackingCost(params.tracking_cost_params),
99
+ interaction_cost_function=TTCCostTorch(params.ttc_cost_params),
100
+ risk_estimator=get_risk_estimator(params.risk_estimator_params),
101
+ )
102
+ self.predictor = predictor
103
+ self.normalizer = normalizer
104
+
105
+ self._ego_state_history = []
106
+ self._ego_state_target_trajectory = None
107
+ self._ego_state_planned_trajectory = None
108
+
109
+ self._ado_state_history = []
110
+ self._latest_ado_position_future_samples = None
111
+
112
+ def replan(
113
+ self,
114
+ current_ado_state: AbstractState,
115
+ current_ego_state: AbstractState,
116
+ target_velocity: torch.Tensor,
117
+ num_prediction_samples: int = 1,
118
+ risk_level: float = 0.0,
119
+ resample_prediction: bool = False,
120
+ risk_in_predictor: bool = False,
121
+ ) -> None:
122
+ """Performs re-planning given the current_ado_position, current_ego_state, and
123
+ target_velocity. Updates ego_state_planned_trajectory. Note that all the information given
124
+ to the solver.solve(...) is expressed in the ego-centric frame, whose origin is the initial
125
+ ego position in ego_state_history and the x-direction is parallel to the initial ego
126
+ velocity.
127
+
128
+ Args:
129
+ current_ado_state: current ado (non-ego) state
130
+ current_ego_state: ego state
131
+ target_velocity: ((1), 2) tensor
132
+ num_prediction_samples (optional): number of prediction samples. Defaults to 1.
133
+ risk_level (optional): a risk-level float for the entire prediction-planning pipeline.
134
+ If 0.0, risk-neutral prediction and planning are used. Defaults to 0.0.
135
+ resample_prediction (optional): If True, prediction is re-sampled in each cross-entropy
136
+ iteration. Defaults to False.
137
+ risk_in_predictor (optional): If True, risk-biased prediction is used and the solver
138
+ becomes risk-neutral. If False, risk-neutral prediction is used and the solver becomes
139
+ risk-sensitive. Defaults to False.
140
+ """
141
+ self._update_ado_state_history(current_ado_state)
142
+ self._update_ego_state_history(current_ego_state)
143
+ self._update_ego_state_target_trajectory(current_ego_state, target_velocity)
144
+ if not self.ado_state_history.shape[-1] < self.params.num_steps:
145
+ self.solver.solve(
146
+ self.predictor,
147
+ self._map_to_ego_centric_frame(self.ego_state_history),
148
+ self._map_to_ego_centric_frame(self._ego_state_target_trajectory),
149
+ self._map_to_ego_centric_frame(self.ado_state_history),
150
+ self.normalizer,
151
+ num_prediction_samples=num_prediction_samples,
152
+ risk_level=risk_level,
153
+ resample_prediction=resample_prediction,
154
+ risk_in_predictor=risk_in_predictor,
155
+ )
156
+ ego_state_planned_trajectory_in_ego_frame = self.dynamics_model.simulate(
157
+ self._map_to_ego_centric_frame(self.ego_state_history[..., -1]),
158
+ self.solver.control_sequence,
159
+ )
160
+ self._ego_state_planned_trajectory = self._map_to_world_frame(
161
+ ego_state_planned_trajectory_in_ego_frame
162
+ )
163
+ latest_ado_position_future_samples_in_ego_frame = (
164
+ self.solver.fetch_latest_prediction()
165
+ )
166
+ if latest_ado_position_future_samples_in_ego_frame is not None:
167
+ self._latest_ado_position_future_samples = self._map_to_world_frame(
168
+ latest_ado_position_future_samples_in_ego_frame
169
+ )
170
+ else:
171
+ self._latest_ado_position_future_samples = None
172
+
173
+ def get_planned_next_ego_state(self) -> AbstractState:
174
+ """Returns the next ego state according to the ego_state_planned_trajectory
175
+
176
+ Returns:
177
+ Planned state
178
+ """
179
+ assert (
180
+ self._ego_state_planned_trajectory is not None
181
+ ), "call self.replan(...) first"
182
+ return self._ego_state_planned_trajectory[..., 0]
183
+
184
+ def reset(self) -> None:
185
+ """Resets the planner's internal state. This will fully reset the solver's internal state,
186
+ including solver.control_input_mean_init and solver.control_input_std_init."""
187
+ self.solver.control_input_mean_init = (
188
+ self.control_input_mean_init.detach().clone()
189
+ )
190
+ self.solver.control_input_std_init = (
191
+ self.control_input_std_init.detach().clone()
192
+ )
193
+ self.solver.reset()
194
+
195
+ self._ego_state_history = []
196
+ self._ego_state_target_trajectory = None
197
+ self._ego_state_planned_trajectory = None
198
+
199
+ self._ado_state_history = []
200
+ self._latest_ado_position_future_samples = None
201
+
202
+ def fetch_latest_prediction(self) -> torch.Tensor:
203
+ if self._latest_ado_position_future_samples is not None:
204
+ return self._latest_ado_position_future_samples
205
+ else:
206
+ return None
207
+
208
+ @property
209
+ def ego_state_history(self) -> torch.Tensor:
210
+ """Returns ego_state_history as a concatenated tensor
211
+ Returns:
212
+ ego_state_history tensor
213
+ """
214
+ assert len(self._ego_state_history) > 0
215
+ return to_state(
216
+ torch.stack(
217
+ [ego_state.get_states(4) for ego_state in self._ego_state_history],
218
+ dim=-2,
219
+ ),
220
+ self.params.dt,
221
+ )
222
+
223
+ @property
224
+ def ado_state_history(self) -> torch.Tensor:
225
+ """Returns ado_position_history as a concatenated tensor
226
+
227
+ Returns:
228
+ ado_position_history tensor
229
+ """
230
+ assert len(self._ado_state_history) > 0
231
+ return to_state(
232
+ torch.stack(
233
+ [ado_state.get_states(4) for ado_state in self._ado_state_history],
234
+ dim=-2,
235
+ ),
236
+ self.params.dt,
237
+ )
238
+
239
+ def _update_ego_state_history(self, current_ego_state: AbstractState) -> None:
240
+ """Updates ego_state_history with the current_ego_state
241
+
242
+ Args:
243
+ current_ego_state: (1, state_dim) tensor
244
+ """
245
+
246
+ if len(self._ego_state_history) >= self.params.num_steps:
247
+ self._ego_state_history = self._ego_state_history[1:]
248
+ self._ego_state_history.append(current_ego_state)
249
+ assert len(self._ego_state_history) <= self.params.num_steps
250
+
251
+ def _update_ado_state_history(self, current_ado_state: AbstractState) -> None:
252
+ """Updates ego_state_history with the current_ado_position
253
+
254
+ Args:
255
+ current_ado_state states of the current non-ego vehicles
256
+ """
257
+
258
+ if len(self._ado_state_history) >= self.params.num_steps:
259
+ self._ado_state_history = self._ado_state_history[1:]
260
+ self._ado_state_history.append(current_ado_state)
261
+ assert len(self._ado_state_history) <= self.params.num_steps
262
+
263
+ def _update_ego_state_target_trajectory(
264
+ self, current_ego_state: AbstractState, target_velocity: torch.Tensor
265
+ ) -> None:
266
+ """Updates ego_state_target_trajectory based on the current_ego_state and the target_velocity
267
+
268
+ Args:
269
+ current_ego_state: state
270
+ target_velocity: (1, 2) tensor
271
+ """
272
+
273
+ target_displacement = self.params.dt * target_velocity
274
+ target_position_list = [current_ego_state.position]
275
+ for time_idx in range(self.params.num_steps_future):
276
+ target_position_list.append(target_position_list[-1] + target_displacement)
277
+ target_position_list = target_position_list[1:]
278
+ target_position = torch.cat(target_position_list, dim=-2)
279
+ target_state = to_state(
280
+ torch.cat(
281
+ (target_position, target_velocity.expand_as(target_position)), dim=-1
282
+ ),
283
+ self.params.dt,
284
+ )
285
+ self._ego_state_target_trajectory = target_state
286
+
287
+ def _map_to_ego_centric_frame(
288
+ self, trajectory_in_world_frame: AbstractState
289
+ ) -> torch.Tensor:
290
+ """Maps trajectory epxressed in the world frame to the ego-centric frame, whose origin is
291
+ the initial ego position in ego_state_history and the x-direction is parallel to the initial
292
+ ego velocity
293
+
294
+ Args:
295
+ trajectory: sequence of states
296
+
297
+ Returns:
298
+ trajectory mapped to the ego-centric frame
299
+ """
300
+ # If trajectory_in_world_frame is of shape (..., state_dim) then use the associated
301
+ # dynamics model in translate_position and rotate_angle. Otherwise assume that the
302
+ # trajectory is in the 2D position space.
303
+
304
+ ego_pos_init = self.ego_state_history.position[..., -1, :]
305
+ ego_vel_init = self.ego_state_history.velocity[..., -1, :]
306
+ ego_rot_init = torch.atan2(ego_vel_init[..., 1], ego_vel_init[..., 0])
307
+ trajectory_in_ego_frame = trajectory_in_world_frame.translate(
308
+ -ego_pos_init
309
+ ).rotate(-ego_rot_init)
310
+ return trajectory_in_ego_frame
311
+
312
+ def _map_to_world_frame(
313
+ self, trajectory_in_ego_frame: torch.Tensor
314
+ ) -> torch.Tensor:
315
+ """Maps trajectory epxressed in the ego-centric frame to the world frame
316
+
317
+ Args:
318
+ trajectory_in_ego_frame: (..., 2) position trajectory or (..., markov_state_dim) state
319
+ trajectory expressed in the ego-centric frame, whose origin is the initial ego
320
+ position in ego_state_history and the x-direction is parallel to the initial ego
321
+ velocity
322
+
323
+ Returns:
324
+ trajectory mapped to the world frame
325
+ """
326
+ # state starts with x, y, angle
327
+ ego_pos_init = self.ego_state_history.position[..., -1, :]
328
+ ego_rot_init = self.ego_state_history.angle[..., -1, :]
329
+ trajectory_in_world_frame = trajectory_in_ego_frame.rotate(
330
+ ego_rot_init
331
+ ).translate(ego_pos_init)
332
+ return trajectory_in_world_frame
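The ego-frame change used above is a plain 2D rigid transform; below is a small self-contained sketch on raw positions (the planner itself delegates this to AbstractState.translate and rotate, and the function name here is purely illustrative).

import torch

def to_ego_frame(points_world: torch.Tensor, ego_position: torch.Tensor, ego_velocity: torch.Tensor) -> torch.Tensor:
    """Translate by -ego_position, then rotate by -heading, with heading = atan2(vy, vx)."""
    angle = torch.atan2(ego_velocity[1], ego_velocity[0])
    cos, sin = torch.cos(-angle), torch.sin(-angle)
    rotation = torch.stack((torch.stack((cos, -sin)), torch.stack((sin, cos))))
    return (points_world - ego_position) @ rotation.T

points = torch.tensor([[3.0, 1.0], [4.0, 1.0]])
ego_pos = torch.tensor([1.0, 1.0])
ego_vel = torch.tensor([1.0, 0.0])  # heading along +x, so only the translation acts here
print(to_ego_frame(points, ego_pos, ego_vel))  # approximately [[2., 0.], [3., 0.]]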
risk_biased/mpc_planner/planner_cost.py ADDED
@@ -0,0 +1,127 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from mmcv import Config
6
+
7
+
8
+ @dataclass
9
+ class TrackingCostParams:
10
+ scale_longitudinal: float
11
+ scale_lateral: float
12
+ reduce: str
13
+
14
+ @staticmethod
15
+ def from_config(cfg: Config):
16
+ return TrackingCostParams(
17
+ scale_longitudinal=cfg.tracking_cost_scale_longitudinal,
18
+ scale_lateral=cfg.tracking_cost_scale_lateral,
19
+ reduce=cfg.tracking_cost_reduce,
20
+ )
21
+
22
+
23
+ class TrackingCost:
24
+ """Quadratic Trajectory Tracking Cost
25
+
26
+ Args:
27
+ params: tracking cost parameters
28
+ """
29
+
30
+ def __init__(self, params: TrackingCostParams) -> None:
31
+ self.scale_longitudinal = params.scale_longitudinal
32
+ self.scale_lateral = params.scale_lateral
33
+ assert params.reduce in [
34
+ "min",
35
+ "max",
36
+ "mean",
37
+ "now",
38
+ "final",
39
+ ], "unsupported reduce type"
40
+ self._reduce_fun_name = params.reduce
41
+
42
+ def __call__(
43
+ self,
44
+ ego_position_trajectory: torch.Tensor,
45
+ target_position_trajectory: torch.Tensor,
46
+ target_velocity_trajectory: torch.Tensor,
47
+ ) -> torch.Tensor:
48
+ """Computes quadratic tracking cost
49
+
50
+ Args:
51
+ ego_position_trajectory: (some_shape, num_some_steps, 2) tensor of ego
52
+ position trajectory
53
+ target_position_trajectory: (some_shape, num_some_steps, 2) tensor of
54
+ ego target position trajectory
55
+ target_velocity_trajectory: (some_shape, num_some_steps, 2) tensor of
56
+ ego target velocity trajectory
57
+
58
+ Returns:
59
+ (some_shape) cost
60
+ """
61
+ cost_matrix = self._get_quadratic_cost_matrix(target_velocity_trajectory)
62
+ cost = (
63
+ (
64
+ (ego_position_trajectory - target_position_trajectory).unsqueeze(-2)
65
+ @ cost_matrix
66
+ @ (ego_position_trajectory - target_position_trajectory).unsqueeze(-1)
67
+ )
68
+ .squeeze(-1)
69
+ .squeeze(-1)
70
+ )
71
+ return self._reduce(cost, dim=-1)
72
+
73
+ def _reduce(self, cost: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
74
+ """Reduces the cost tensor based on self._reduce_fun_name
75
+
76
+ Args:
77
+ cost: cost tensor of some shape where the last dimension represents time
78
+ dim (optional): tensor dimension to be reduced. Defaults to None.
79
+
80
+ Returns:
81
+ reduced cost tensor
82
+ """
83
+ if self._reduce_fun_name == "min":
84
+ return torch.min(cost, dim=dim)[0] if dim is not None else torch.min(cost)
85
+ if self._reduce_fun_name == "max":
86
+ return torch.max(cost, dim=dim)[0] if dim is not None else torch.max(cost)
87
+ if self._reduce_fun_name == "mean":
88
+ return torch.mean(cost, dim=dim) if dim is not None else torch.mean(cost)
89
+ if self._reduce_fun_name == "now":
90
+ return cost[..., 0]
91
+ if self._reduce_fun_name == "final":
92
+ return cost[..., -1]
93
+
94
+ def _get_quadratic_cost_matrix(
95
+ self, target_velocity_trajectory: torch.Tensor, eps: float = 1e-8
96
+ ) -> torch.Tensor:
97
+ """Gets quadratic cost matrix based on target velocity direction per time step.
98
+ If the target velocity norm is 0, an all-zero matrix is returned for that time step.
99
+
100
+ Args:
101
+ target_velocity_trajectory: (some_shape, num_some_steps, 2) tensor of
102
+ ego target velocity trajectory
103
+ eps (optional): small positive number to ensure numerical stability. Defaults to
104
+ 1e-8.
105
+
106
+ Returns:
107
+ (some_shape, num_some_steps, 2, 2) quadratic cost matrix
108
+ """
109
+ longitudinal_direction = (
110
+ target_velocity_trajectory
111
+ / (
112
+ torch.linalg.norm(target_velocity_trajectory, dim=-1).unsqueeze(-1)
113
+ + eps
114
+ )
115
+ ).unsqueeze(-1)
116
+ rotation_90_deg = torch.Tensor([[[0.0, -1.0], [1.0, 0]]])
117
+ lateral_direction = rotation_90_deg @ longitudinal_direction
118
+ orthogonal_matrix = torch.cat(
119
+ (longitudinal_direction, lateral_direction), dim=-1
120
+ )
121
+ eigen_matrix = torch.Tensor(
122
+ [[[self.scale_longitudinal, 0.0], [0.0, self.scale_lateral]]]
123
+ )
124
+ cost_matrix = (
125
+ orthogonal_matrix @ eigen_matrix @ orthogonal_matrix.transpose(-1, -2)
126
+ )
127
+ return cost_matrix
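Note: _get_quadratic_cost_matrix builds Q = R @ diag(scale_longitudinal, scale_lateral) @ R^T, where the columns of R are the unit longitudinal and lateral directions of the target velocity, so position errors are weighted differently along and across the target motion. The short standalone check below (plain torch, illustrative values, not the class above) verifies that reading.

import torch

scale_longitudinal, scale_lateral = 1.0, 4.0
target_velocity = torch.tensor([[3.0, 4.0]])                          # (1, 2): one time step
lon = target_velocity / target_velocity.norm(dim=-1, keepdim=True)    # unit longitudinal direction
lat = torch.stack((-lon[..., 1], lon[..., 0]), dim=-1)                # 90-degree rotation of lon
R = torch.stack((lon, lat), dim=-1)                                   # (1, 2, 2), columns = directions
Q = R @ torch.diag(torch.tensor([scale_longitudinal, scale_lateral])) @ R.transpose(-1, -2)

error = torch.tensor([[1.0, 0.0]])                                    # position tracking error
cost = (error.unsqueeze(-2) @ Q @ error.unsqueeze(-1)).squeeze()
# Same cost from the decomposition into longitudinal and lateral error components:
expected = scale_longitudinal * (error * lon).sum(-1) ** 2 + scale_lateral * (error * lat).sum(-1) ** 2
assert torch.allclose(cost, expected.squeeze())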
risk_biased/mpc_planner/solver.py ADDED
@@ -0,0 +1,429 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, Dict, List, Optional, Tuple
3
+
4
+ from mmcv import Config
5
+ import numpy as np
6
+ import torch
7
+
8
+ from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
9
+ from risk_biased.mpc_planner.planner_cost import TrackingCost
10
+ from risk_biased.predictors.biased_predictor import LitTrajectoryPredictor
11
+ from risk_biased.utils.cost import BaseCostTorch
12
+ from risk_biased.utils.planner_utils import (
13
+ AbstractState,
14
+ to_state,
15
+ evaluate_risk,
16
+ get_interaction_cost,
17
+ )
18
+ from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator
19
+
20
+
21
+ @dataclass
22
+ class CrossEntropySolverParams:
23
+ """Dataclass for Cross Entropy Solver Parameters
24
+
25
+ Args:
26
+ num_control_samples: number of Monte Carlo samples for control input
27
+ num_elite: number of elite samples
28
+ iter_max: maximum iteration number
29
+ smoothing_factor: smoothing factor in (0, 1) used to update the mean and the std of the
30
+ control input distribution for the next iteration. If 0, the updated distribution is
31
+ independent of the previous iteration. If 1, the updated distribution is the same as the
32
+ previous iteration.
33
+ mean_warm_start: internally saves control_input_mean of the last iteration of the current
34
+ solve, so that control_input_mean will be warm-started in the next solve
35
+ """
36
+
37
+ num_control_samples: int
38
+ num_elite: int
39
+ iter_max: int
40
+ smoothing_factor: float
41
+ mean_warm_start: bool
42
+ dt: float
43
+
44
+ @staticmethod
45
+ def from_config(cfg: Config):
46
+ return CrossEntropySolverParams(
47
+ cfg.num_control_samples,
48
+ cfg.num_elite,
49
+ cfg.iter_max,
50
+ cfg.smoothing_factor,
51
+ cfg.mean_warm_start,
52
+ cfg.dt,
53
+ )
54
+
55
+
56
+ class CrossEntropySolver:
57
+ """Cross Entropy Solver for MPC Planner
58
+
59
+ Args:
60
+ params: CrossEntropySolverParams object
61
+ dynamics_model: dynamics model for control
62
+ control_input_mean: (num_agents, num_steps_future, control_dim) tensor of control input mean
63
+ control_input_std: (num_agents, num_steps_future, control_dim) tensor of control input std
64
+ tracking_cost_function: deterministic tracking cost that does not involve ado
65
+ interaction_cost_function: interaction cost function between ego and (stochastic) ado
66
+ risk_estimator (optional): Monte Carlo risk estimator for risk computation. If None,
67
+ risk-neutral expectation is used for selection of elites. Defaults to None.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ params: CrossEntropySolverParams,
73
+ dynamics_model: PositionVelocityDoubleIntegrator,
74
+ control_input_mean: torch.Tensor,
75
+ control_input_std: torch.Tensor,
76
+ tracking_cost_function: TrackingCost,
77
+ interaction_cost_function: BaseCostTorch,
78
+ risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
79
+ ) -> None:
80
+ self.params = params
81
+
82
+ self.control_input_mean_init = control_input_mean.detach().clone()
83
+ self.control_input_std_init = control_input_std.detach().clone()
84
+ assert (
85
+ self.control_input_mean_init.shape == self.control_input_std_init.shape
86
+ ), "control input mean and std must have the same size"
87
+ assert (
88
+ self.control_input_mean_init.shape[-1] == dynamics_model.control_dim
89
+ ), f"control dimension must be {dynamics_model.control_dim}"
90
+
91
+ self.dynamics_model = dynamics_model
92
+ self.tracking_cost = tracking_cost_function
93
+ self.interaction_cost = interaction_cost_function
94
+ self.risk_estimator = risk_estimator
95
+
96
+ self._iter_current = None
97
+ self._control_input_mean = None
98
+ self._control_input_std = None
99
+
100
+ self._latest_ado_position_future_samples = None
101
+
102
+ self.reset()
103
+
104
+ def reset(self) -> None:
105
+ """Resets the solver's internal state"""
106
+ self._iter_current = 0
107
+ self._control_input_mean = self.control_input_mean_init.clone()
108
+ self._control_input_std = self.control_input_std_init.clone()
109
+ self._latest_ado_position_future_samples = None
110
+
111
+ def step(
112
+ self,
113
+ ego_state_history: AbstractState,
114
+ ego_state_target_trajectory: AbstractState,
115
+ ado_state_future_samples: AbstractState,
116
+ weights: torch.Tensor,
117
+ verbose: bool = False,
118
+ risk_level: float = 0.0,
119
+ ) -> Dict:
120
+ """Performs one iteration step of the Cross Entropy Method
121
+
122
+ Args:
123
+ ego_state_history: (num_agents, num_steps) ego state history
124
+ ego_state_target_trajectory: (num_agents, num_steps_future) ego target
125
+ state trajectory
126
+ ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
127
+ predicted ado trajectory samples
128
+ weights: (num_prediction_samples, num_agents) prediction sample weight
129
+ verbose (optional): Print progress. Defaults to False.
130
+ risk_level (optional): a risk-level float for the solver. If 0.0, risk-neutral
131
+ expectation is used for selection of elites. Defaults to 0.0.
132
+
133
+ Return:
134
+ Dictionary containing information about this solver step.
135
+ """
136
+
137
+ self._iter_current += 1
138
+ ego_control_input = torch.normal(
139
+ self._control_input_mean.expand(
140
+ self.params.num_control_samples, -1, -1, -1
141
+ ),
142
+ self._control_input_std.expand(self.params.num_control_samples, -1, -1, -1),
143
+ )
144
+ if verbose:
145
+ print(f"**Cross Entropy Iteration {self._iter_current}")
146
+ print(
147
+ f"****Drawing ego's control input samples of {ego_control_input.shape}"
148
+ )
149
+ ego_state_current = ego_state_history[..., -1]
150
+ ego_state_future = self.dynamics_model.simulate(
151
+ ego_state_current, ego_control_input
152
+ )
153
+ if verbose:
154
+ print(f"****Simulating ego's future state trajectory")
155
+
156
+ # state starts with x, y, angle, vx, vy
157
+ tracking_cost = self.tracking_cost(
158
+ ego_state_future.position,
159
+ ego_state_target_trajectory.position,
160
+ ego_state_target_trajectory.velocity,
161
+ )
162
+ if verbose:
163
+ print(
164
+ f"****Computing tracking cost of {tracking_cost.shape} for the control input samples"
165
+ )
166
+
167
+ # state starts with x, y
168
+ interaction_cost = get_interaction_cost(
169
+ ego_state_future,
170
+ ado_state_future_samples,
171
+ self.interaction_cost,
172
+ )
173
+ if verbose:
174
+ print(
175
+ f"****Computing interaction cost of {interaction_cost.shape} for the control input samples"
176
+ )
177
+ interaction_risk = evaluate_risk(
178
+ risk_level,
179
+ interaction_cost,
180
+ weights.permute(1, 0).unsqueeze(0).expand_as(interaction_cost),
181
+ self.risk_estimator,
182
+ )
183
+
184
+ total_risk = interaction_risk + tracking_cost
185
+ elite_ego_control_input, elite_total_risk = self._get_elites(
186
+ ego_control_input, total_risk
187
+ )
188
+ if verbose:
189
+ print(f"****Selecting {self.params.num_elite} elite samples")
190
+ print(f"****Elite Total_Risk Information: {elite_total_risk}")
191
+
192
+ info = dict(
193
+ iteration=self._iter_current,
194
+ control_input_mean=self._control_input_mean.detach().cpu().numpy().copy(),
195
+ control_input_std=self._control_input_std.detach().cpu().numpy().copy(),
196
+ ego_state_future=ego_state_future.get_states(5)
197
+ .detach()
198
+ .cpu()
199
+ .numpy()
200
+ .copy(),
201
+ ado_state_future_samples=ado_state_future_samples.get_states(5)
202
+ .detach()
203
+ .cpu()
204
+ .numpy()
205
+ .copy(),
206
+ sample_weights=weights.detach().cpu().numpy().copy(),
207
+ tracking_cost=tracking_cost.detach().cpu().numpy().copy(),
208
+ interaction_cost=interaction_cost.detach().cpu().numpy().copy(),
209
+ total_risk=total_risk.detach().cpu().numpy().copy(),
210
+ )
211
+
212
+ self._update_control_distribution(elite_ego_control_input)
213
+ if verbose:
214
+ print("****Updating ego's control distribution")
215
+
216
+ return info
217
+
218
+ def solve(
219
+ self,
220
+ predictor: LitTrajectoryPredictor,
221
+ ego_state_history: AbstractState,
222
+ ego_state_target_trajectory: AbstractState,
223
+ ado_state_history: AbstractState,
224
+ normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
225
+ num_prediction_samples: int = 1,
226
+ verbose: bool = False,
227
+ risk_level: float = 0.0,
228
+ resample_prediction: bool = False,
229
+ risk_in_predictor: bool = False,
230
+ ) -> List[Dict]:
231
+ """Performs Cross Entropy optimization of ego's control input
232
+
233
+ Args:
234
+ predictor: LitTrajectoryPredictor object
235
+ ego_state_history: (num_agents, num_steps, state_dim) ego state history
236
+ ego_state_target_trajectory: (num_agents, num_steps_future, state_dim) ego target
237
+ state trajectory
238
+ ado_state_history: (num_agents, num_steps, state_dim) ado state history
239
+ normalizer: function that takes in an unnormalized trajectory and that outputs the
240
+ normalized trajectory and the offset in this order
241
+ num_prediction_samples: number of prediction samples. Defaults to 1.
242
+ verbose (optional): Print progress. Defaults to False.
243
+ risk_level (optional): a risk-level float for the entire prediction-planning pipeline.
244
+ If 0.0, risk-neutral prediction and planning are used. Defaults to 0.0.
245
+ resample_prediction (optional): If True, prediction is re-sampled in each cross-entropy
246
+ iteration. Defaults to False.
247
+ risk_in_predictor (optional): If True, risk-biased prediction is used and the solver
248
+ becomes risk-neutral. If False, risk-neutral prediction is used and the solver becomes
249
+ risk-sensitive. Defaults to False.
250
+
251
+ Return:
252
+ List of dictionaries each containing information about the corresponding solver step.
253
+ """
254
+ if risk_level == 0.0:
255
+ risk_level_planner, risk_level_predictor = 0.0, 0.0
256
+ else:
257
+ if risk_in_predictor:
258
+ risk_level_planner, risk_level_predictor = 0.0, risk_level
259
+ else:
260
+ risk_level_planner, risk_level_predictor = risk_level, 0.0
261
+ self.reset()
262
+ infos = []
263
+ ego_state_future = self.dynamics_model.simulate(
264
+ ego_state_history[..., -1],
265
+ self.control_sequence,
266
+ )
267
+ for iter in range(self.params.iter_max):
268
+ assert iter == self._iter_current
269
+ if resample_prediction or self._iter_current == 0:
270
+ ado_state_future_samples, weights = self.sample_prediction(
271
+ predictor,
272
+ ado_state_history,
273
+ normalizer,
274
+ ego_state_history,
275
+ ego_state_future,
276
+ num_prediction_samples,
277
+ risk_level_predictor,
278
+ )
279
+ self._latest_ado_position_future_samples = ado_state_future_samples
280
+ info = self.step(
281
+ ego_state_history,
282
+ ego_state_target_trajectory,
283
+ ado_state_future_samples,
284
+ weights,
285
+ verbose=verbose,
286
+ risk_level=risk_level_planner,
287
+ )
288
+ infos.append(info)
289
+ if self.params.mean_warm_start:
290
+ self.control_input_mean_init[:, :-1] = (
291
+ self._control_input_mean[:, 1:].detach().clone()
292
+ )
293
+ return infos
294
+
295
+ @property
296
+ def control_sequence(self) -> torch.Tensor:
297
+ """Returns the planned control sequence, which is a detached copy of the control input mean
298
+ tensor
299
+
300
+ Returns:
301
+ (num_agents, num_steps_future, control_dim) control sequence tensor
302
+ """
303
+ return self._control_input_mean.detach().clone()
304
+
305
+ @staticmethod
306
+ def sample_prediction(
307
+ predictor: LitTrajectoryPredictor,
308
+ ado_state_history: AbstractState,
309
+ normalizer: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]],
310
+ ego_state_history: AbstractState,
311
+ ego_state_future: AbstractState,
312
+ num_prediction_samples: int = 1,
313
+ risk_level: float = 0.0,
314
+ ) -> Tuple[AbstractState, torch.Tensor]:
315
+ """Sample prediction from the predictor given the history, normalizer, and the desired
316
+ risk-level
317
+
318
+ Args:
319
+ predictor: LitTrajectoryPredictor object
320
+ ado_state_history: (num_agents, num_steps, state_dim) tensor of ado position history
321
+ normalizer: function that takes in an unnormalized trajectory and that outputs the
322
+ normalized trajectory and the offset in this order
323
+ ego_state_history: (num_agents, num_steps, state_dim) tensor of ego state history
324
+ ego_state_future: (num_agents, num_steps_future, state_dim) tensor of ego state future
325
+ num_prediction_samples (optional): number of prediction samples. Defaults to 1.
326
+ risk_level (optional): a risk-level float for the predictor. If 0.0, risk-neutral
327
+ prediction is sampled. Defaults to 0.0.
328
+
329
+ Returns:
330
+ state samples of shape (num_agents, num_prediction_samples, num_steps_future)
331
+ probability weights of the samples of shape (num_agents, num_prediction_samples)
332
+ """
333
+ ado_position_history_normalized, offset = normalizer(
334
+ ado_state_history.get_states(predictor.dynamic_state_dim)
335
+ .unsqueeze(0)
336
+ .expand(num_prediction_samples, -1, -1, -1)
337
+ )
338
+
339
+ x = ado_position_history_normalized.clone()
340
+ mask_x = torch.ones_like(x[..., 0])
341
+ map = torch.empty(num_prediction_samples, 0, 0, 2, device=x.device)
342
+ mask_map = torch.empty(num_prediction_samples, 0, 0, device=x.device)
343
+
344
+ batch = (
345
+ x,
346
+ mask_x,
347
+ map,
348
+ mask_map,
349
+ offset,
350
+ ego_state_history.get_states(predictor.dynamic_state_dim)
351
+ .unsqueeze(0)
352
+ .expand(num_prediction_samples, -1, -1, -1),
353
+ ego_state_future.get_states(predictor.dynamic_state_dim)
354
+ .unsqueeze(0)
355
+ .expand(num_prediction_samples, -1, -1, -1),
356
+ )
357
+
358
+ ado_position_future_samples, weights = predictor.predict_step(
359
+ batch,
360
+ 0,
361
+ risk_level=risk_level,
362
+ return_weights=True,
363
+ )
364
+ ado_position_future_samples = ado_position_future_samples.detach().cpu()
365
+ weights = weights.detach().cpu()
366
+
367
+ return to_state(ado_position_future_samples, predictor.dt), weights
368
+
369
+ def fetch_latest_prediction(self):
370
+ if self._latest_ado_position_future_samples is not None:
371
+ return self._latest_ado_position_future_samples
372
+ else:
373
+ return None
374
+
375
+ def _get_elites(
376
+ self, control_input: torch.Tensor, risk: torch.Tensor
377
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
378
+ """Selects elite control input based on corresponding risk (lower the better)
379
+
380
+ Args:
381
+ control_input: (num_control_samples, num_agents, num_steps_future, control_dim) control samples
382
+ risk: (num_control_samples, num_agents) risk tensor
383
+
384
+ Returns:
385
+ elite_control_input: (num_elite, num_agents, num_steps_future, control_dim) elite control
386
+ elite_risk: (num_elite, num_agents) elite risk
387
+ """
388
+ num_control_samples = self.params.num_control_samples
389
+ assert (
390
+ control_input.shape[0] == num_control_samples
391
+ ), f"size of control_input tensor must be {num_control_samples} at dimension 0"
392
+ assert (
393
+ risk.shape[0] == num_control_samples
394
+ ), f"size of risk tensor must be {num_control_samples} at dimension 0"
395
+
396
+ _, sorted_risk_indices = torch.sort(risk, dim=0)
397
+ elite_control_input = control_input[
398
+ sorted_risk_indices[: self.params.num_elite], np.arange(risk.shape[1])
399
+ ]
400
+ elite_risk = risk[
401
+ sorted_risk_indices[: self.params.num_elite], np.arange(risk.shape[1])
402
+ ]
403
+ return elite_control_input, elite_risk
404
+
405
+ def _update_control_distribution(self, elite_control_input: torch.Tensor) -> None:
406
+ """Updates control input distribution using elites
407
+
408
+ Args:
409
+ elite_control_input: (num_elite, num_agents, num_steps_future, control_dim) elite control
410
+ """
411
+ num_elite, smoothing_factor = (
412
+ self.params.num_elite,
413
+ self.params.smoothing_factor,
414
+ )
415
+ assert (
416
+ elite_control_input.shape[0] == num_elite
417
+ ), f"size of elite_control_input tensor must be {num_elite} at dimension 0"
418
+
419
+ elite_control_input_mean = elite_control_input.mean(dim=0, keepdim=False)
420
+ if num_elite < 2:
421
+ elite_control_input_std = torch.zeros_like(elite_control_input_mean)
422
+ else:
423
+ elite_control_input_std = elite_control_input.std(dim=0, keepdim=False)
424
+ self._control_input_mean = (
425
+ 1.0 - smoothing_factor
426
+ ) * elite_control_input_mean + smoothing_factor * self._control_input_mean
427
+ self._control_input_std = (
428
+ 1.0 - smoothing_factor
429
+ ) * elite_control_input_std + smoothing_factor * self._control_input_std
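Note: stripped of prediction and risk estimation, CrossEntropySolver follows the standard cross-entropy method: sample controls from a Gaussian, score them, keep the lowest-cost elites, and refit the mean and std with smoothing. The toy sketch below applies the same loop to a 1-D quadratic objective (hypothetical names, no dynamics, no risk estimator).

import torch

def cross_entropy_minimize(cost_fn, mean, std, num_samples=64, num_elite=8,
                           iter_max=20, smoothing_factor=0.5):
    # Minimal loop mirroring CrossEntropySolver.step and _update_control_distribution.
    for _ in range(iter_max):
        samples = torch.normal(mean.expand(num_samples, *mean.shape),
                               std.expand(num_samples, *std.shape))
        costs = cost_fn(samples)                       # (num_samples,)
        elites = samples[torch.argsort(costs)[:num_elite]]
        mean = (1.0 - smoothing_factor) * elites.mean(0) + smoothing_factor * mean
        std = (1.0 - smoothing_factor) * elites.std(0) + smoothing_factor * std
    return mean

# Minimize (u - 3)^2 starting far from the optimum: the mean converges towards 3.
solution = cross_entropy_minimize(lambda u: (u.squeeze(-1) - 3.0) ** 2,
                                  mean=torch.zeros(1), std=5.0 * torch.ones(1))
print(solution)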
risk_biased/predictors/biased_predictor.py ADDED
@@ -0,0 +1,568 @@
1
+ from dataclasses import dataclass
2
+ from functools import partial
3
+ from typing import Callable, List, Optional, Tuple, Union
4
+
5
+
6
+ from einops import repeat
7
+ from mmcv import Config
8
+ import pytorch_lightning as pl
9
+ import torch
10
+
11
+ from risk_biased.models.cvae_params import CVAEParams
12
+ from risk_biased.models.biased_cvae_model import (
13
+ cvae_factory,
14
+ )
15
+
16
+ from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
17
+ from risk_biased.utils.risk import get_risk_estimator
18
+ from risk_biased.utils.risk import get_risk_level_sampler
19
+
20
+
21
+ @dataclass
22
+ class LitTrajectoryPredictorParams:
23
+ """
24
+ cvae_params: CVAEParams class defining the necessary parameters for the CVAE model
25
+ risk distribution: dict of string and values defining the risk distribution to use
26
+ risk_estimator: dict of string and values defining the risk estimator to use
27
+ kl_weight: float defining the weight of the KL term in the loss function
28
+ kl_threshold: float defining the threshold to apply when computing kl divergence (avoid posterior collapse)
29
+ risk_weight: float defining the weight of the risk term in the loss function
30
+ n_mc_samples_risk: int defining the number of Monte Carlo samples to use when estimating the risk
31
+ n_mc_samples_biased: int defining the number of Monte Carlo samples to use when estimating the expected biased cost
32
+ dt: float defining the duration between two consecutive time steps
33
+ learning_rate: float defining the learning rate for the optimizer
34
+ use_risk_constraint: bool defining whether to use the risk constrained optimization procedure
35
+ risk_constraint_update_every_n_epoch: int defining the number of epochs between two risk weight updates
36
+ risk_constraint_weight_update_factor: float defining the factor by which the risk weight is multiplied at each update
37
+ risk_constraint_weight_maximum: float defining the maximum value of the risk weight
38
+ num_samples_min_fde: int defining the number of samples to use when estimating the minimum FDE
39
+ condition_on_ego_future: bool defining whether to condition the biasing on the ego future trajectory (else on the ego past)
40
+
41
+ """
42
+
43
+ cvae_params: CVAEParams
44
+ risk_distribution: dict
45
+ risk_estimator: dict
46
+ kl_weight: float
47
+ kl_threshold: float
48
+ risk_weight: float
49
+ n_mc_samples_risk: int
50
+ n_mc_samples_biased: int
51
+ dt: float
52
+ learning_rate: float
53
+ use_risk_constraint: bool
54
+ risk_constraint_update_every_n_epoch: int
55
+ risk_constraint_weight_update_factor: float
56
+ risk_constraint_weight_maximum: float
57
+ num_samples_min_fde: int
58
+ condition_on_ego_future: bool
59
+
60
+ @staticmethod
61
+ def from_config(cfg: Config):
62
+ cvae_params = CVAEParams.from_config(cfg)
63
+ return LitTrajectoryPredictorParams(
64
+ risk_distribution=cfg.risk_distribution,
65
+ risk_estimator=cfg.risk_estimator,
66
+ kl_weight=cfg.kl_weight,
67
+ kl_threshold=cfg.kl_threshold,
68
+ risk_weight=cfg.risk_weight,
69
+ n_mc_samples_risk=cfg.n_mc_samples_risk,
70
+ n_mc_samples_biased=cfg.n_mc_samples_biased,
71
+ dt=cfg.dt,
72
+ learning_rate=cfg.learning_rate,
73
+ cvae_params=cvae_params,
74
+ use_risk_constraint=cfg.use_risk_constraint,
75
+ risk_constraint_update_every_n_epoch=cfg.risk_constraint_update_every_n_epoch,
76
+ risk_constraint_weight_update_factor=cfg.risk_constraint_weight_update_factor,
77
+ risk_constraint_weight_maximum=cfg.risk_constraint_weight_maximum,
78
+ num_samples_min_fde=cfg.num_samples_min_fde,
79
+ condition_on_ego_future=cfg.condition_on_ego_future,
80
+ )
81
+
82
+
83
+ class LitTrajectoryPredictor(pl.LightningModule):
84
+ """Pytorch Lightning Module for Trajectory Prediction with the biased cvae model
85
+
86
+ Args:
87
+ params : dataclass object containing the necessary parameters
88
+ cost_params: dataclass object defining the TTC cost function
89
+ unnormalizer: function that takes in a trajectory and an offset and that outputs the
90
+ unnormalized trajectory
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ params: LitTrajectoryPredictorParams,
96
+ cost_params: TTCCostParams,
97
+ unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
98
+ ) -> None:
99
+ super().__init__()
100
+ model = cvae_factory(
101
+ params.cvae_params,
102
+ cost_function=TTCCostTorch(cost_params),
103
+ risk_estimator=get_risk_estimator(params.risk_estimator),
104
+ training_mode="cvae",
105
+ )
106
+ self.model = model
107
+ self.params = params
108
+ self._unnormalize_trajectory = unnormalizer
109
+ self.set_training_mode("cvae")
110
+
111
+ self.learning_rate = params.learning_rate
112
+ self.num_samples_min_fde = params.num_samples_min_fde
113
+
114
+ self.dynamic_state_dim = params.cvae_params.dynamic_state_dim
115
+ self.dt = params.cvae_params.dt
116
+
117
+ self.use_risk_constraint = params.use_risk_constraint
118
+ self.risk_weight = params.risk_weight
119
+ self.risk_weight_ratio = params.risk_weight / params.kl_weight
120
+ self.kl_weight = params.kl_weight
121
+ if self.use_risk_constraint:
122
+ self.risk_constraint_update_every_n_epoch = (
123
+ params.risk_constraint_update_every_n_epoch
124
+ )
125
+ self.risk_constraint_weight_update_factor = (
126
+ params.risk_constraint_weight_update_factor
127
+ )
128
+ self.risk_constraint_weight_maximum = params.risk_constraint_weight_maximum
129
+
130
+ self._risk_sampler = get_risk_level_sampler(params.risk_distribution)
131
+
132
+ def set_training_mode(self, training_mode: str):
133
+ self.model.set_training_mode(training_mode)
134
+ self.partial_get_loss = partial(
135
+ self.model.get_loss,
136
+ kl_threshold=self.params.kl_threshold,
137
+ n_samples_risk=self.params.n_mc_samples_risk,
138
+ n_samples_biased=self.params.n_mc_samples_biased,
139
+ dt=self.params.dt,
140
+ unnormalizer=self._unnormalize_trajectory,
141
+ )
142
+
143
+ def _get_loss(
144
+ self,
145
+ x: torch.Tensor,
146
+ mask_x: torch.Tensor,
147
+ map: torch.Tensor,
148
+ mask_map: torch.Tensor,
149
+ y: torch.Tensor,
150
+ mask_y: torch.Tensor,
151
+ mask_loss: torch.Tensor,
152
+ x_ego: torch.Tensor,
153
+ y_ego: torch.Tensor,
154
+ offset: Optional[torch.Tensor] = None,
155
+ risk_level: Optional[torch.Tensor] = None,
156
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, ...]], dict]:
157
+ """Compute loss based on trajectory history x and future y
158
+
159
+ Args:
160
+ x: (batch_size, num_agents, num_steps, state_dim) tensor of history
161
+ mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
162
+ map: (batch_size, num_objects, object_sequence_length, map_feature_dim) tensor of encoded map objects
163
+ mask_map: (batch_size, num_objects, object_sequence_length) tensor True where map features are good False where it is padding
164
+ y: (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
165
+ mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask.
166
+ mask_loss: (batch_size, num_agents, num_steps_future) tensor of bool mask set to True where the loss
167
+ should be computed and to False where it shouldn't
168
+ offset : (batch_size, num_agents, state_dim) offset position from ego
169
+ risk_level : (batch_size, num_agents) tensor of risk levels desired for future trajectories
170
+
171
+ Returns:
172
+ Union[torch.Tensor, Tuple[torch.Tensor, ...]]: (1,) loss tensor or tuple of
173
+ loss tensors
174
+ dict: dict that contains values to be logged
175
+ """
176
+ return self.partial_get_loss(
177
+ x=x,
178
+ mask_x=mask_x,
179
+ map=map,
180
+ mask_map=mask_map,
181
+ y=y,
182
+ mask_y=mask_y,
183
+ mask_loss=mask_loss,
184
+ offset=offset,
185
+ risk_level=risk_level,
186
+ x_ego=x_ego,
187
+ y_ego=y_ego,
188
+ risk_weight=self.risk_weight,
189
+ kl_weight=self.kl_weight,
190
+ )
191
+
192
+ def log_with_prefix(
193
+ self,
194
+ log_dict: dict,
195
+ prefix: Optional[str] = None,
196
+ on_step: Optional[bool] = None,
197
+ on_epoch: Optional[bool] = None,
198
+ ) -> None:
199
+ """log entries in log_dict while optionally adding "<prefix>/" to its keys
200
+
201
+ Args:
202
+ log_dict: dict that contains values to be logged
203
+ prefix: prefix to be added to keys
204
+ on_step: if True logs at this step. None auto-logs at the training_step but not
205
+ validation/test_step
206
+ on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test
207
+ step but not training_step
208
+ """
209
+ if prefix is None:
210
+ prefix = ""
211
+ else:
212
+ prefix += "/"
213
+
214
+ for (metric, value) in log_dict.items():
215
+ metric = prefix + metric
216
+ self.log(metric, value, on_step=on_step, on_epoch=on_epoch)
217
+
218
+ def configure_optimizers(
219
+ self,
220
+ ) -> Union[torch.optim.Optimizer, List[torch.optim.Optimizer]]:
221
+ """Configure optimizer for PyTorch-Lightning
222
+
223
+ Returns:
224
+ torch.optim.Optimizer: optimizer to be used for training
225
+ """
226
+ if isinstance(self.model.get_parameters(), list):
227
+ self._optimizers = [
228
+ torch.optim.Adam(params, lr=self.learning_rate)
229
+ for params in self.model.get_parameters()
230
+ ]
231
+ else:
232
+ self._optimizers = [
233
+ torch.optim.Adam(self.model.get_parameters(), lr=self.learning_rate)
234
+ ]
235
+ return self._optimizers
236
+
237
+ def training_step(
238
+ self,
239
+ batch: Tuple[
240
+ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
241
+ ],
242
+ batch_idx: int,
243
+ ) -> dict:
244
+ """Training step definition for PyTorch-Lightning
245
+
246
+ Args:
247
+ batch : [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
248
+ (batch_size, num_agents, num_steps), # mask past False where past trajectories are padding data
249
+ (batch_size, num_agents, num_steps_future, state_dim), # future trajectory
250
+ (batch_size, num_agents, num_steps_future), # mask future False where future trajectories are padding data
251
+ (batch_size, num_agents, num_steps_future), # mask loss False where future trajectories are not to be predicted
252
+ (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
253
+ (batch_size, num_objects, object_seq_len), # mask map False where map objects are padding data
254
+ (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
255
+ (batch_size, 1, num_steps, state_dim), # ego past trajectory
256
+ (batch_size, 1, num_steps_future, state_dim)] # ego future trajectory
257
+ batch_idx : batch_idx to be used by PyTorch-Lightning
258
+
259
+ Returns:
260
+ dict: dict of outputs containing loss
261
+ """
262
+ x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
263
+ risk_level = repeat(
264
+ self._risk_sampler.sample(x.shape[0], x.device),
265
+ "b -> b num_agents",
266
+ num_agents=x.shape[1],
267
+ )
268
+ loss, log_dict = self._get_loss(
269
+ x=x,
270
+ mask_x=mask_x,
271
+ map=map,
272
+ mask_map=mask_map,
273
+ y=y,
274
+ mask_y=mask_y,
275
+ mask_loss=mask_loss,
276
+ offset=offset,
277
+ risk_level=risk_level,
278
+ x_ego=x_ego,
279
+ y_ego=y_ego,
280
+ )
281
+ if isinstance(loss, tuple):
282
+ loss = sum(loss)
283
+ self.log_with_prefix(log_dict, prefix="train", on_step=True, on_epoch=False)
284
+
285
+ return {"loss": loss}
286
+
287
+ def training_epoch_end(self, outputs: List[dict]) -> None:
288
+ """Called at the end of the training epoch with the outputs of all training steps
289
+
290
+ Args:
291
+ outputs: list of outputs of all training steps in the current epoch
292
+ """
293
+ if self.use_risk_constraint:
294
+ if (
295
+ self.model.training_mode == "bias"
296
+ and (self.trainer.current_epoch + 1)
297
+ % self.risk_constraint_update_every_n_epoch
298
+ == 0
299
+ ):
300
+ self.risk_weight_ratio *= self.risk_constraint_weight_update_factor
301
+ if self.risk_weight_ratio < self.risk_constraint_weight_maximum:
302
+ sum_weight = self.risk_weight + self.kl_weight
303
+ self.risk_weight = (
304
+ sum_weight
305
+ * self.risk_weight_ratio
306
+ / (1 + self.risk_weight_ratio)
307
+ )
308
+ self.kl_weight = sum_weight / (1 + self.risk_weight_ratio)
309
+ # self.risk_weight *= self.risk_constraint_weight_update_factor
310
+ # if self.risk_weight > self.risk_constraint_weight_maximum:
311
+ # self.risk_weight = self.risk_constraint_weight_maximum
312
+
313
+ def _get_risk_tensor(
314
+ self,
315
+ batch_size: int,
316
+ num_agents: int,
317
+ device: torch.device,
318
+ risk_level: Optional[torch.Tensor] = None,
319
+ ):
320
+ """This function is used to reformat different possible formattings of risk_level input arguments into a tensor of shape (batch_size).
321
+ If given a tensor the same tensor is returned.
322
+ If given a float value, a tensor of this value is returned.
323
+ If given None, a tensor filled with random samples is returned.
324
+
325
+ Args:
326
+ batch_size : desired batch size
327
+ device : device on which we want to store risk
328
+ risk_level : The risk level as a tensor, a float value or None
329
+
330
+ Returns:
331
+ _type_: _description_
332
+ """
333
+ if risk_level is not None:
334
+ if isinstance(risk_level, float):
335
+ risk_level = (
336
+ torch.ones(batch_size, num_agents, device=device) * risk_level
337
+ )
338
+ else:
339
+ risk_level = risk_level.to(device)
340
+ else:
341
+ risk_level = None
342
+
343
+ return risk_level
344
+
345
+ def validation_step(
346
+ self,
347
+ batch: Tuple[
348
+ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
349
+ ],
350
+ batch_idx: int,
351
+ risk_level: float = 1.0,
352
+ ) -> dict:
353
+ """Validation step definition for PyTorch-Lightning
354
+
355
+ Args:
356
+ batch : [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
357
+ (batch_size, num_agents, num_steps), # mask past False where past trajectories are padding data
358
+ (batch_size, num_agents, num_steps_future, state_dim), # future trajectory
359
+ (batch_size, num_agents, num_steps_future), # mask future False where future trajectories are padding data
360
+ (batch_size, num_agents, num_steps_future), # mask loss False where future trajectories are not to be predicted
361
+ (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
362
+ (batch_size, num_objects, object_seq_len), # mask map False where map objects are padding data
363
+ (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
364
+ (batch_size, 1, num_steps, state_dim), # ego past trajectory
365
+ (batch_size, 1, num_steps_future, state_dim)] # ego future trajectory
366
+ batch_idx : batch_idx to be used by PyTorch-Lightning
367
+ risk_level : optional desired risk level
368
+
369
+ Returns:
370
+ dict: dict of outputs containing loss
371
+ """
372
+ x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
373
+
374
+ risk_level = self._get_risk_tensor(
375
+ x.shape[0], x.shape[1], x.device, risk_level=risk_level
376
+ )
377
+ self.model.eval()
378
+ log_dict_accuracy = self.model.get_prediction_accuracy(
379
+ x=x,
380
+ mask_x=mask_x,
381
+ map=map,
382
+ mask_map=mask_map,
383
+ y=y,
384
+ mask_loss=mask_loss,
385
+ offset=offset,
386
+ x_ego=x_ego,
387
+ y_ego=y_ego,
388
+ unnormalizer=self._unnormalize_trajectory,
389
+ risk_level=risk_level,
390
+ num_samples_min_fde=self.num_samples_min_fde,
391
+ )
392
+
393
+ loss, log_dict_loss = self._get_loss(
394
+ x=x,
395
+ mask_x=mask_x,
396
+ map=map,
397
+ mask_map=mask_map,
398
+ y=y,
399
+ mask_y=mask_y,
400
+ mask_loss=mask_loss,
401
+ offset=offset,
402
+ risk_level=risk_level,
403
+ x_ego=x_ego,
404
+ y_ego=y_ego,
405
+ )
406
+
407
+ if isinstance(loss, tuple):
408
+ loss = sum(loss)
409
+
410
+ self.log_with_prefix(
411
+ dict(log_dict_accuracy, **log_dict_loss),
412
+ prefix="val",
413
+ on_step=False,
414
+ on_epoch=True,
415
+ )
416
+ self.model.train()
417
+ return {"loss": loss}
418
+
419
+ def test_step(
420
+ self,
421
+ batch: Tuple[
422
+ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
423
+ ],
424
+ batch_idx: int,
425
+ risk_level: Optional[torch.Tensor] = None,
426
+ ) -> dict:
427
+ """Test step definition for PyTorch-Lightning
428
+
429
+ Args:
430
+ batch : [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
431
+ (batch_size, num_agents, num_steps), # mask past False where past trajectories are padding data
432
+ (batch_size, num_agents, num_steps_future, state_dim), # future trajectory
433
+ (batch_size, num_agents, num_steps_future), # mask future False where future trajectories are padding data
434
+ (batch_size, num_agents, num_steps_future), # mask loss False where future trajectories are not to be predicted
435
+ (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
436
+ (batch_size, num_objects, object_seq_len), # mask map False where map objects are padding data
437
+ (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
438
+ (batch_size, 1, num_steps, state_dim), # ego past trajectory
439
+ (batch_size, 1, num_steps_future, state_dim)] # ego future trajectory
440
+ batch_idx : batch_idx to be used by PyTorch-Lightning
441
+ risk_level : optional desired risk level
442
+
443
+ Returns:
444
+ dict: dict of outputs containing loss
445
+ """
446
+ x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
447
+ risk_level = self._get_risk_tensor(
448
+ x.shape[0], x.shape[1], x.device, risk_level=risk_level
449
+ )
450
+ loss, log_dict = self._get_loss(
451
+ x=x,
452
+ mask_x=mask_x,
453
+ map=map,
454
+ mask_map=mask_map,
455
+ y=y,
456
+ mask_y=mask_y,
457
+ mask_loss=mask_loss,
458
+ offset=offset,
459
+ risk_level=risk_level,
460
+ x_ego=x_ego,
461
+ y_ego=y_ego,
462
+ )
463
+ if isinstance(loss, tuple):
464
+ loss = sum(loss)
465
+ self.log_with_prefix(log_dict, prefix="test", on_step=False, on_epoch=True)
466
+ return {"loss": loss}
467
+
468
+ def predict_step(
469
+ self,
470
+ batch: Tuple[torch.Tensor, torch.Tensor],
471
+ batch_idx: int = 0,
472
+ risk_level: Optional[torch.Tensor] = None,
473
+ n_samples: int = 0,
474
+ return_weights: bool = False,
475
+ ) -> torch.Tensor:
476
+ """Predict step definition for PyTorch-Lightning
477
+
478
+ Args:
479
+ batch: [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
480
+ (batch_size, num_agents, num_steps), # mask past False where past trajectories are padding data
481
+ (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
482
+ (batch_size, num_objects, object_seq_len), # mask map False where map objects are padding data
483
+ (batch_size, num_agents, state_dim), # position offset of all agents relative to ego at present time
484
+ (batch_size, 1, num_steps, state_dim), # past trajectory of the ego agent in the scene
485
+ (batch_size, 1, num_steps_future, state_dim),] # future trajectory of the ego agent in the scene
486
+ batch_idx : batch_idx to be used by PyTorch-Lightning (unused here)
487
+ risk_level : optional desired risk level
488
+ n_samples: Number of samples to predict per agent
489
+ With value of 0 does not include the `n_samples` dim in the output.
490
+ return_weights: If True, also returns the sample weights
491
+
492
+ Returns:
493
+ (batch_size, (n_samples), num_steps_future, state_dim) tensor
494
+ """
495
+ x, mask_x, map, mask_map, offset, x_ego, y_ego = batch
496
+ risk_level = self._get_risk_tensor(
497
+ batch_size=x.shape[0],
498
+ num_agents=x.shape[1],
499
+ device=x.device,
500
+ risk_level=risk_level,
501
+ )
502
+ y_sampled, weights, _ = self.model(
503
+ x,
504
+ mask_x,
505
+ map,
506
+ mask_map,
507
+ offset=offset,
508
+ x_ego=x_ego,
509
+ y_ego=y_ego,
510
+ risk_level=risk_level,
511
+ n_samples=n_samples,
512
+ )
513
+ predict_sampled = self._unnormalize_trajectory(y_sampled, offset)
514
+ if return_weights:
515
+ return predict_sampled, weights
516
+ else:
517
+ return predict_sampled
518
+
519
+ def predict_loop_once(
520
+ self,
521
+ batch: Tuple[torch.Tensor, torch.Tensor],
522
+ batch_idx: int = 0,
523
+ risk_level: Optional[torch.Tensor] = None,
524
+ ) -> torch.Tensor:
525
+ """Predict with refinment:
526
+ A first prediction is done as in predict_step, however instead of unnormalize and return it,
527
+ it is fed to the encoder that wast trained to encode past and ground truth future.
528
+ Then the decoder is used again but its latent input sample is biased by the encoder
529
+ instead of being a sample of the prior distribution.
530
+ Then as in predict_step the result is unnormalized and returned.
531
+
532
+ Args:
533
+ batch: [(batch_size, num_agents, num_steps, state_dim), # past trajectories of all agents in the scene
534
+ (batch_size, num_agents, num_steps), # mask past False where past trajectories are padding data
535
+ (batch_size, num_objects, object_seq_len, state_dim), # map object sequences in the scene
536
+ (batch_size, num_objects, object_seq_len), # mask map False where map objects are padding data
537
+ (batch_size, num_agents, state_dim),] # position offset of all agents relative to ego at present time
538
+ batch_idx : batch_idx to be used by PyTorch-Lightning (Unused here). Defaults to 0.
539
+ risk_level : optional desired risk level
540
+
541
+ Returns:
542
+ torch.Tensor: (batch_size, num_steps_future, state_dim) tensor
543
+ """
544
+ x, mask_x, map, mask_map, offset = batch
545
+ risk_level = self._get_risk_tensor(
546
+ x.shape[0], x.shape[1], x.device, risk_level=risk_level
547
+ )
548
+ y_sampled, _ = self.model(
549
+ x,
550
+ mask_x,
551
+ map,
552
+ mask_map,
553
+ offset=offset,
554
+ risk_level=risk_level,
555
+ )
556
+ mask_y = repeat(mask_x.any(-1), "b a -> b a f", f=y_sampled.shape[-2])
557
+ y_sampled, _ = self.model(
558
+ x,
559
+ mask_x,
560
+ map,
561
+ mask_map,
562
+ y_sampled,
563
+ mask_y,
564
+ offset=offset,
565
+ risk_level=risk_level,
566
+ )
567
+ predict_sampled = self._unnormalize_trajectory(y_sampled, offset=offset)
568
+ return predict_sampled
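Note: in training_epoch_end above, the risk-constrained schedule grows the risk/KL weight ratio geometrically while keeping the sum of the two weights constant. The small sketch below replays that renormalization with illustrative values (not the repository's defaults).

# Illustrative values only; the repository reads these from its config.
risk_weight, kl_weight = 0.1, 1.0
risk_weight_ratio = risk_weight / kl_weight
update_factor, ratio_maximum = 1.5, 10.0

for epoch in range(5):
    risk_weight_ratio *= update_factor
    if risk_weight_ratio < ratio_maximum:
        sum_weight = risk_weight + kl_weight
        risk_weight = sum_weight * risk_weight_ratio / (1 + risk_weight_ratio)
        kl_weight = sum_weight / (1 + risk_weight_ratio)
    # The sum risk_weight + kl_weight stays constant while the ratio shifts towards risk.
    print(epoch, round(risk_weight, 4), round(kl_weight, 4))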
risk_biased/scene_dataset/__init__.py ADDED
File without changes
risk_biased/scene_dataset/loaders.py ADDED
@@ -0,0 +1,252 @@
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import DataLoader, TensorDataset
6
+
7
+
8
+ class SceneDataLoaders:
9
+ """
10
+ This class loads a scene dataset and pre-process it (normalization, unnormalization)
11
+
12
+ Args:
13
+ state_dim : dimension of the observed state (2 for x,y position observation)
14
+ num_steps : number of observed steps
15
+ num_steps_future : number of steps in the future
16
+ batch_size: set data loader with this batch size
17
+ data_train: training dataset
18
+ data_val: validation dataset
19
+ data_test: test dataset
20
+ num_workers: number of workers to use for data loading
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ state_dim: int,
26
+ num_steps: int,
27
+ num_steps_future: int,
28
+ batch_size: int,
29
+ data_train: torch.Tensor,
30
+ data_val: torch.Tensor,
31
+ data_test: torch.Tensor,
32
+ num_workers: int = 0,
33
+ ):
34
+ self._batch_size = batch_size
35
+ self._num_workers = num_workers
36
+ self._state_dim = state_dim
37
+ self._num_steps = num_steps
38
+ self._num_steps_future = num_steps_future
39
+
40
+ self._setup_datasets(data_train, data_val, data_test)
41
+
42
+ def train_dataloader(self, shuffle=True, drop_last=True) -> DataLoader:
43
+ """Setup and return training DataLoader
44
+
45
+ Returns:
46
+ DataLoader: training DataLoader
47
+ """
48
+ data_size = self._data_train_past.shape[0]
49
+ # This is a didactic data loader that only defines minimalistic inputs.
50
+ # This dataloader adds some empty tensors and ones to match the expected format with masks and map information.
51
+ train_loader = DataLoader(
52
+ dataset=TensorDataset(
53
+ self._data_train_past,
54
+ torch.ones_like(self._data_train_past[..., 0]), # Mask past
55
+ self._data_train_fut,
56
+ torch.ones_like(self._data_train_fut[..., 0]), # Mask fut
57
+ torch.ones_like(self._data_train_fut[..., 0]), # Mask loss
58
+ torch.empty(
59
+ data_size, 1, 0, 0, device=self._data_train_past.device
60
+ ), # Map
61
+ torch.empty(
62
+ data_size, 1, 0, device=self._data_train_past.device
63
+ ), # Mask map
64
+ self._offset_train,
65
+ self._data_train_ego_past,
66
+ self._data_train_ego_fut,
67
+ ),
68
+ batch_size=self._batch_size,
69
+ shuffle=shuffle,
70
+ drop_last=drop_last,
71
+ num_workers=self._num_workers,
72
+ )
73
+ return train_loader
74
+
75
+ def val_dataloader(self, shuffle=False, drop_last=False) -> DataLoader:
76
+ """Setup and return validation DataLoader
77
+
78
+ Returns:
79
+ DataLoader: validation DataLoader
80
+ """
81
+ data_size = self._data_val_past.shape[0]
82
+ # This is a didactic data loader that only defines minimalistic inputs.
83
+ # This dataloader adds some empty tensors and ones to match the expected format with masks and map information.
84
+ val_loader = DataLoader(
85
+ dataset=TensorDataset(
86
+ self._data_val_past,
87
+ torch.ones_like(self._data_val_past[..., 0]), # Mask past
88
+ self._data_val_fut,
89
+ torch.ones_like(self._data_val_fut[..., 0]), # Mask fut
90
+ torch.ones_like(self._data_val_fut[..., 0]), # Mask loss
91
+ torch.zeros(
92
+ data_size, 1, 0, 0, device=self._data_val_past.device
93
+ ), # Map
94
+ torch.ones(
95
+ data_size, 1, 0, device=self._data_val_past.device
96
+ ), # Mask map
97
+ self._offset_val,
98
+ self._data_val_ego_past,
99
+ self._data_val_ego_fut,
100
+ ),
101
+ batch_size=self._batch_size,
102
+ shuffle=shuffle,
103
+ drop_last=drop_last,
104
+ num_workers=self._num_workers,
105
+ )
106
+ return val_loader
107
+
108
+ def test_dataloader(self) -> DataLoader:
109
+ """Setup and return test DataLoader
110
+
111
+ Returns:
112
+ DataLoader: test DataLoader
113
+ """
114
+ data_size = self._data_test_past.shape[0]
115
+ # This is a didactic data loader that only defines minimalistic inputs.
116
+ # This dataloader adds some empty tensors and ones to match the expected format with masks and map information.
117
+ test_loader = DataLoader(
118
+ dataset=TensorDataset(
119
+ self._data_test_past,
120
+ torch.ones_like(self._data_test_past[..., 0]), # Mask
121
+ torch.zeros(
122
+ data_size, 0, 1, 0, device=self._data_test_past.device
123
+ ), # Map
124
+ torch.ones(
125
+ data_size, 0, 1, device=self._data_test_past.device
126
+ ), # Mask map
127
+ self._offset_test,
128
+ self._data_test_ego_past,
129
+ self._data_test_ego_fut,
130
+ ),
131
+ batch_size=self._batch_size,
132
+ shuffle=False,
133
+ num_workers=self._num_workers,
134
+ )
135
+ return test_loader
136
+
137
+ def _setup_datasets(
138
+ self, data_train: torch.Tensor, data_val: torch.Tensor, data_test: torch.Tensor
139
+ ):
140
+ """Setup datasets: normalize and split into past future
141
+ Args:
142
+ data_train: training dataset
143
+ data_val: validation dataset
144
+ data_test: test dataset
145
+ """
146
+ data_train, data_train_ego = data_train[0], data_train[1]
147
+ data_val, data_val_ego = data_val[0], data_val[1]
148
+ data_test, data_test_ego = data_test[0], data_test[1]
149
+
150
+ data_train, self._offset_train = self.normalize_trajectory(data_train)
151
+ data_val, self._offset_val = self.normalize_trajectory(data_val)
152
+ data_test, self._offset_test = self.normalize_trajectory(data_test)
153
+ # This is a didactic data loader that only defines minimalistic inputs.
154
+ # An extra dimension is added to account for the number of agents in the scene.
155
+ # In this minimal input there is only one but the model using the data expects any number of agents.
156
+ self._data_train_past, self._data_train_fut = self.split_trajectory(data_train)
157
+ self._data_val_past, self._data_val_fut = self.split_trajectory(data_val)
158
+ self._data_test_past, self._data_test_fut = self.split_trajectory(data_test)
159
+
160
+ self._data_train_ego_past, self._data_train_ego_fut = self.split_trajectory(
161
+ data_train_ego
162
+ )
163
+ self._data_val_ego_past, self._data_val_ego_fut = self.split_trajectory(
164
+ data_val_ego
165
+ )
166
+ self._data_test_ego_past, self._data_test_ego_fut = self.split_trajectory(
167
+ data_test_ego
168
+ )
169
+
170
+ def split_trajectory(
171
+ self, input: torch.Tensor
172
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
173
+ """Split input trajectory into history and future
174
+
175
+ Args:
176
+ input : (batch_size, (n_agents), num_steps + num_steps_future, state_dim) tensor of
177
+ entire trajectory [x, y]
178
+
179
+ Returns:
180
+ Tuple of history and future trajectories
181
+ """
182
+ assert (
183
+ input.shape[-2] == self._num_steps + self._num_steps_future
184
+ ), "trajectory length ({}) does not match the expected length".format(
185
+ input.shape[-2]
186
+ )
187
+ assert (
188
+ input.shape[-1] == self._state_dim
189
+ ), "state dimension ({}) does not match the expected dimension".format(
190
+ input.shape[-1]
191
+ )
192
+
193
+ input_history, input_future = torch.split(
194
+ input, [self._num_steps, self._num_steps_future], dim=-2
195
+ )
196
+ return input_history, input_future
197
+
198
+ @staticmethod
199
+ def normalize_trajectory(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
200
+ """Normalize input trajectory by subtracting initial state
201
+
202
+ Args:
203
+ input : (some_shape, n_agents, num_steps + num_steps_future, state_dim) tensor of
204
+ entire trajectory [x, y], or (some_shape, num_steps, state_dim) tensor of history x
205
+
206
+ Returns:
207
+ Tuple of (normalized_trajectory, offset), where
208
+ normalized_trajectory has the same dimension as the input and offset is a
209
+ (some_shape, state_dim) tensor corresponding to the initial state
210
+ """
211
+ offset = input[..., 0, :].clone()
212
+
213
+ return input - offset.unsqueeze(-2), offset
214
+
215
+ @staticmethod
216
+ def unnormalize_trajectory(
217
+ input: torch.Tensor, offset: torch.Tensor
218
+ ) -> torch.Tensor:
219
+ """Unnormalize trajectory by adding offset to input
220
+
221
+ Args:
222
+ input : (some_shape, (n_sample), num_steps_future, state_dim) tensor of future
223
+ trajectory y
224
+ offset : (some_shape, 2 or 4 or 5) tensor of offset to add to y
225
+
226
+ Returns:
227
+ Unnormalized trajectory that has the same size as input
228
+ """
229
+ offset_dim = offset.shape[-1]
230
+ assert input.shape[-1] >= offset_dim
231
+ input_clone = input.clone()
232
+ if offset.ndim == 2:
233
+ batch_size, _ = offset.shape
234
+ assert input_clone.shape[0] == batch_size
235
+
236
+ input_clone[..., :offset_dim] = input_clone[
237
+ ..., :offset_dim
238
+ ] + offset.reshape(
239
+ [batch_size, *([1] * (input_clone.ndim - 2)), offset_dim]
240
+ )
241
+ elif offset.ndim == 3:
242
+ batch_size, num_agents, _ = offset.shape
243
+ assert input_clone.shape[0] == batch_size
244
+ assert input_clone.shape[1] == num_agents
245
+
246
+ input_clone[..., :offset_dim] = input_clone[
247
+ ..., :offset_dim
248
+ ] + offset.reshape(
249
+ [batch_size, num_agents, *([1] * (input_clone.ndim - 3)), offset_dim]
250
+ )
251
+
252
+ return input_clone
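Note: normalize_trajectory and unnormalize_trajectory above are a subtract-offset / add-offset pair, so a round trip should recover the original trajectory. The quick standalone check below (plain tensors, random data) illustrates that.

import torch

# Fake batch of trajectories: (batch_size, n_agents, num_steps_total, state_dim)
trajectories = torch.cumsum(torch.randn(4, 2, 10, 2), dim=-2)

offset = trajectories[..., 0, :].clone()           # initial state of each agent
normalized = trajectories - offset.unsqueeze(-2)   # every trajectory now starts at the origin
restored = normalized + offset.unsqueeze(-2)       # adding the offset back inverts the normalization

assert torch.allclose(restored, trajectories, atol=1e-6)
assert torch.allclose(normalized[..., 0, :], torch.zeros_like(offset))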
risk_biased/scene_dataset/pedestrian.py ADDED
@@ -0,0 +1,165 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import Tensor
4
+ from typing import Union
5
+
6
+
7
+ class RandomPedestrians:
8
+ """
9
+ Batched random pedestrians.
10
+ There are two types of pedestrians, slow and fast ones.
11
+ Each pedestrian mainly walks at the constant favored speed of its type, but at each time step there is a probability that it changes its pace.
12
+
13
+ Args:
14
+ batch_size: int number of scenes in the batch
15
+ dt: float time step to use in the trajectory sequence
16
+ fast_speed: float fast walking speed for the random pedestrian in meters/seconds
17
+ slow_speed: float slow walking speed for the random pedestrian in meters/seconds
18
+ p_change_pace: float probability that a slow (resp. fast) pedestrian walks at fast_speed (resp. slow_speed) at each time step
19
+ proportion_fast: float proportion of the pedestrians that are mainly walking at fast_speed
20
+ is_torch: bool set to True to produce Tensor batches and to False to produce numpy arrays
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ batch_size: int,
26
+ dt: float = 0.1,
27
+ fast_speed: float = 2,
28
+ slow_speed: float = 1,
29
+ p_change_pace: float = 0.1,
30
+ proportion_fast: float = 0.5,
31
+ is_torch: bool = False,
32
+ ) -> None:
33
+
34
+ self.is_torch = is_torch
35
+ self.fast_speed: float = fast_speed
36
+ self.slow_speed: float = slow_speed
37
+ self.dt: float = dt
38
+ self.p_change_pace: float = p_change_pace
39
+ self.batch_size: int = batch_size
40
+
41
+ self.propotion_fast: float = proportion_fast
42
+ if self.is_torch:
43
+ self.is_fast_type: Tensor = torch.from_numpy(
44
+ np.random.binomial(1, self.propotion_fast, [batch_size, 1, 1]).astype(
45
+ "float32"
46
+ )
47
+ )
48
+ self.is_currently_fast: Tensor = self.is_fast_type.clone()
49
+ self.initial_position: Tensor = torch.zeros([batch_size, 1, 2])
50
+ self.position: Tensor = self.initial_position.clone()
51
+ self._angle: Tensor = (2 * torch.rand(batch_size, 1) - 1) * np.pi
52
+ self.unit_velocity: Tensor = torch.stack(
53
+ (torch.cos(self._angle), torch.sin(self._angle)), -1
54
+ )
55
+ else:
56
+ self.is_fast_type: np.ndarray = np.random.binomial(
57
+ 1, self.proportion_fast, [batch_size, 1, 1]
58
+ )
59
+ self.is_currently_fast: np.ndarray = self.is_fast_type.copy()
60
+ self.initial_position: np.ndarray = np.zeros([batch_size, 1, 2])
61
+ self.position: np.ndarray = self.initial_position.copy()
62
+ self._angle: np.ndarray = np.random.uniform(-np.pi, np.pi, (batch_size, 1))
63
+ self.unit_velocity: np.ndarray = np.stack(
64
+ (np.cos(self._angle), np.sin(self._angle)), -1
65
+ )
66
+
67
+ @property
68
+ def angle(self):
69
+ return self._angle
70
+
71
+ @angle.setter
72
+ def angle(self, angle: Union[np.ndarray, torch.Tensor]):
73
+ assert self.batch_size == angle.shape[0]
74
+ if self.is_torch:
75
+ assert isinstance(angle, torch.Tensor)
76
+ self._angle = angle
77
+ self.unit_velocity = torch.stack(
78
+ (torch.cos(self._angle), torch.sin(self._angle)), -1
79
+ )
80
+ else:
81
+ assert isinstance(angle, np.ndarray)
82
+ self._angle = angle
83
+ self.unit_velocity = np.stack(
84
+ (np.cos(self._angle), np.sin(self._angle)), -1
85
+ )
86
+
87
+ def step(self) -> None:
88
+ """
89
+ Forward one time step, update the speed selection and the current position.
90
+ """
91
+ self.update_speed()
92
+ self.update_position()
93
+
94
+ def update_speed(self) -> None:
95
+ """
96
+ Update the speed as a random selection between favored speed and the other speed with probability self.p_change_pace.
97
+ """
98
+ if self.is_torch:
99
+ do_flip = (
100
+ torch.from_numpy(
101
+ np.random.binomial(1, self.p_change_pace, self.batch_size).astype(
102
+ "float32"
103
+ )
104
+ )
105
+ == 1
106
+ )
107
+ self.is_currently_fast = self.is_fast_type.clone()
108
+ else:
109
+ do_flip = np.random.binomial(1, self.p_change_pace, self.batch_size) == 1
110
+ self.is_currently_fast = self.is_fast_type.copy()
111
+ self.is_currently_fast[do_flip] = 1 - self.is_fast_type[do_flip]
112
+
113
+ def update_position(self) -> None:
114
+ """
115
+ Update the position as current position + time_step*speed*(cos(angle), sin(angle))
116
+ """
117
+ self.position += (
118
+ self.dt
119
+ * (
120
+ self.slow_speed
121
+ + (self.fast_speed - self.slow_speed) * self.is_currently_fast
122
+ )
123
+ * self.unit_velocity
124
+ )
125
+
126
+ def travel_distance(self) -> Union[np.ndarray, Tensor]:
127
+ """
128
+ Return the travel distance between initial position and current position.
129
+ """
130
+ if self.is_torch:
131
+ return torch.sqrt(
132
+ torch.sum(torch.square(self.position - self.initial_position), -1)
133
+ )
134
+ else:
135
+ return np.sqrt(np.sum(np.square(self.position - self.initial_position), -1))
136
+
137
+ def get_final_position(self, time: float) -> Union[np.ndarray, Tensor]:
138
+ """
139
+ Return a sample of pedestrian final positions using their speed distribution.
140
+ (This is stochastic, different samples will produce different results).
141
+ Args:
142
+ time: The final time at which to get the position
143
+ Returns:
144
+ The batch of final positions
145
+ """
146
+ num_steps = int(round(time / self.dt))
147
+ if self.is_torch:
148
+ cumulative_change_state = torch.from_numpy(
149
+ np.random.binomial(
150
+ num_steps, self.p_change_pace, [self.batch_size, 1, 1]
151
+ ).astype("float32")
152
+ )
153
+ else:
154
+ cumulative_change_state = np.random.binomial(
155
+ num_steps, self.p_change_pace, [self.batch_size, 1, 1]
156
+ )
157
+
158
+ num_fast_steps = (
159
+ num_steps - 2 * cumulative_change_state
160
+ ) * self.is_fast_type + cumulative_change_state
161
+
162
+ return self.position + self.unit_velocity * self.dt * (
163
+ self.slow_speed * num_steps
164
+ + (self.fast_speed - self.slow_speed) * num_fast_steps
165
+ )
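A minimal usage sketch of RandomPedestrians (batch size and rollout length are illustrative):

from risk_biased.scene_dataset.pedestrian import RandomPedestrians

peds = RandomPedestrians(batch_size=8, dt=0.1, proportion_fast=0.5)
for _ in range(20):  # roll the batch forward for 2 seconds
    peds.step()
distances = peds.travel_distance()          # (8, 1) distances in meters
final = peds.get_final_position(time=3.0)   # (8, 1, 2) sampled final positions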
risk_biased/scene_dataset/scene.py ADDED
@@ -0,0 +1,522 @@
1
+ from dataclasses import dataclass
2
+ import os
3
+ from typing import Union, List, Optional
4
+ import warnings
5
+ import copy
6
+
7
+ from mmcv import Config
8
+ import numpy as np
9
+ import torch
10
+ from torch import Tensor
11
+ from torch.utils.data import Dataset
12
+
13
+ from risk_biased.scene_dataset.pedestrian import RandomPedestrians
14
+ from risk_biased.utils.torch_utils import torch_linspace
15
+
16
+
17
+ @dataclass
18
+ class RandomSceneParams:
19
+ """Dataclass that defines all the listed parameters that are necessary for a RandomScene object
20
+
21
+ Args:
22
+ batch_size: number of scenes in the batch
23
+ time_scene: time length of the scene in seconds
24
+ sample_times: list of times to get the positions
25
+ ego_ref_speed: constant reference speed of the ego vehicle in meters/seconds
26
+ ego_speed_init_low: lowest initial speed of the ego vehicle in meters/seconds
27
+ ego_speed_init_high: highest initial speed of the ego vehicle in meters/seconds
28
+ ego_acceleration_mean_low: lowest mean acceleration of the ego vehicle in m/s^2
29
+ ego_acceleration_mean_high: highest mean acceleration of the ego vehicle in m/s^2
30
+ ego_acceleration_std: std for acceleration of the ego vehicle in m/s^2
31
+ ego_length: length of the ego vehicle in meters
32
+ ego_width: width of the ego vehicle in meters
33
+ dt: time step to use in the trajectory sequence
34
+ fast_speed: fast walking speed for the random pedestrian in meters/seconds
35
+ slow_speed: slow walking speed for the random pedestrian in meters/seconds
36
+ p_change_pace: probability that a slow (resp. fast) pedestrian walks at fast_speed (resp. slow_speed) at each time step
37
+ proportion_fast: proportion of the pedestrians that are mainly walking at fast_speed
38
+ perception_noise_std: standard deviation of the gaussian noise that is affecting the position observations
39
+ """
40
+
41
+ batch_size: int
42
+ time_scene: float
43
+ sample_times: list
44
+ ego_ref_speed: float
45
+ ego_speed_init_low: float
46
+ ego_speed_init_high: float
47
+ ego_acceleration_mean_low: float
48
+ ego_acceleration_mean_high: float
49
+ ego_acceleration_std: float
50
+ ego_length: float
51
+ ego_width: float
52
+ dt: float
53
+ fast_speed: float
54
+ slow_speed: float
55
+ p_change_pace: float
56
+ proportion_fast: float
57
+ perception_noise_std: float
58
+
59
+ @staticmethod
60
+ def from_config(cfg: Config):
61
+ return RandomSceneParams(
62
+ batch_size=cfg.batch_size,
63
+ sample_times=cfg.sample_times,
64
+ time_scene=cfg.time_scene,
65
+ ego_ref_speed=cfg.ego_ref_speed,
66
+ ego_speed_init_low=cfg.ego_speed_init_low,
67
+ ego_speed_init_high=cfg.ego_speed_init_high,
68
+ ego_acceleration_mean_low=cfg.ego_acceleration_mean_low,
69
+ ego_acceleration_mean_high=cfg.ego_acceleration_mean_high,
70
+ ego_acceleration_std=cfg.ego_acceleration_std,
71
+ ego_length=cfg.ego_length,
72
+ ego_width=cfg.ego_width,
73
+ dt=cfg.dt,
74
+ fast_speed=cfg.fast_speed,
75
+ slow_speed=cfg.slow_speed,
76
+ p_change_pace=cfg.p_change_pace,
77
+ proportion_fast=cfg.proportion_fast,
78
+ perception_noise_std=cfg.perception_noise_std,
79
+ )
80
+
81
+
82
+ class RandomScene:
83
+ """
84
+ Batched scenes with one vehicle at constant velocity and one random pedestrian. Utility functions to draw the scene and compute risk factors (time to collision etc...)
85
+
86
+ Args:
87
+ params: dataclass containing the necessary parameters
88
+ is_torch: set to True to produce Tensor batches and to False to produce numpy arrays
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ params: RandomSceneParams,
94
+ is_torch: bool = False,
95
+ ) -> None:
96
+
97
+ self._is_torch = is_torch
98
+ self._batch_size = params.batch_size
99
+ self._fast_speed = params.fast_speed
100
+ self._slow_speed = params.slow_speed
101
+ self._p_change_pace = params.p_change_pace
102
+ self._proportion_fast = params.proportion_fast
103
+ self.dt = params.dt
104
+ self.sample_times = params.sample_times
105
+ self.ego_ref_speed = params.ego_ref_speed
106
+ self._ego_speed_init_low = params.ego_speed_init_low
107
+ self._ego_speed_init_high = params.ego_speed_init_high
108
+ self._ego_acceleration_mean_low = params.ego_acceleration_mean_low
109
+ self._ego_acceleration_mean_high = params.ego_acceleration_mean_high
110
+ self._ego_acceleration_std = params.ego_acceleration_std
111
+ self.perception_noise_std = params.perception_noise_std
112
+ self.road_length = (
113
+ params.ego_ref_speed + params.fast_speed
114
+ ) * params.time_scene
115
+ self.time_scene = params.time_scene
116
+ self.lane_width = 3
117
+ self.sidewalks_width = 1.5
118
+ self.road_width = 2 * self.lane_width + 2 * self.sidewalks_width
119
+ self.bottom = -self.lane_width / 2 - self.sidewalks_width
120
+ self.top = 3 * self.lane_width / 2 + self.sidewalks_width
121
+ self.ego_width = 1.75
122
+ self.ego_length = 4
123
+ self.current_time = 0
124
+
125
+ if self._is_torch:
126
+ pedestrians_x = (
127
+ torch.rand(params.batch_size, 1)
128
+ * (self.road_length - self.ego_length / 2)
129
+ + self.ego_length / 2
130
+ )
131
+ pedestrians_y = (
132
+ torch.rand(params.batch_size, 1) * (self.top - self.bottom)
133
+ + self.bottom
134
+ )
135
+ self._pedestrians_positions = torch.stack(
136
+ (pedestrians_x, pedestrians_y), -1
137
+ )
138
+ else:
139
+ pedestrians_x = np.random.uniform(
140
+ low=self.ego_length / 2,
141
+ high=self.road_length,
142
+ size=(params.batch_size, 1),
143
+ )
144
+ pedestrians_y = np.random.uniform(
145
+ low=self.bottom, high=self.top, size=(params.batch_size, 1)
146
+ )
147
+ self._pedestrians_positions = np.stack((pedestrians_x, pedestrians_y), -1)
148
+
149
+ self.pedestrians = RandomPedestrians(
150
+ batch_size=self._batch_size,
151
+ dt=self.dt,
152
+ fast_speed=self._fast_speed,
153
+ slow_speed=self._slow_speed,
154
+ p_change_pace=self._p_change_pace,
155
+ proportion_fast=self._proportion_fast,
156
+ is_torch=self._is_torch,
157
+ )
158
+ self._set_pedestrians()
159
+
160
+ @property
161
+ def pedestrians_positions(self):
162
+ # relative_positions = self._pedestrians_positions/[[(self.road_length - self.ego_length / 2), (self.top - self.bottom)]] - [[self.ego_length / 2, self.bottom]]
163
+ return self._pedestrians_positions
164
+
165
+ def set_pedestrians_states(
166
+ self,
167
+ relative_pedestrians_positions: Union[torch.Tensor, np.ndarray],
168
+ pedestrians_angles: Optional[Union[torch.Tensor, np.ndarray]] = None,
169
+ ):
170
+ """Force pedestrian initial states
171
+
172
+ Args:
173
+ relative_pedestrians_positions: Relative positions in the scene as percentage distance from left to right and from bottom to top
174
+ pedestrians_angles: Pedestrian heading angles in radians
175
+ """
176
+ if self._is_torch:
177
+ assert isinstance(relative_pedestrians_positions, torch.Tensor)
178
+ else:
179
+ assert isinstance(relative_pedestrians_positions, np.ndarray)
180
+
181
+ self._batch_size = relative_pedestrians_positions.shape[0]
182
+ if (0 > relative_pedestrians_positions).any() or (
183
+ relative_pedestrians_positions > 1
184
+ ).any():
185
+ warnings.warn(
186
+ "Some of the given pedestrian initial positions are outside of the road range"
187
+ )
188
+ center_y = (self.top - self.bottom) * relative_pedestrians_positions[
189
+ :, :, 1
190
+ ] + self.bottom
191
+ center_x = (
192
+ self.road_length - self.ego_length / 2
193
+ ) * relative_pedestrians_positions[:, :, 0] + self.ego_length / 2
194
+ if self._is_torch:
195
+ pedestrians_positions = torch.stack([center_x, center_y], -1)
196
+ else:
197
+ pedestrians_positions = np.stack([center_x, center_y], -1)
198
+
199
+ self.pedestrians = RandomPedestrians(
200
+ batch_size=self._batch_size,
201
+ dt=self.dt,
202
+ fast_speed=self._fast_speed,
203
+ slow_speed=self._slow_speed,
204
+ p_change_pace=self._p_change_pace,
205
+ proportion_fast=self._proportion_fast,
206
+ is_torch=self._is_torch,
207
+ )
208
+ self._pedestrians_positions = pedestrians_positions
209
+ if pedestrians_angles is not None:
210
+ self.pedestrians.angle = pedestrians_angles
211
+ self._set_pedestrians()
212
+
213
+ def _set_pedestrians(self):
214
+ self.pedestrians_trajectories = self.sample_pedestrians_trajectories(
215
+ self.sample_times
216
+ )
217
+
218
+ self.final_pedestrians_positions = self.pedestrians_trajectories[:, :, -1]
219
+
220
+ def get_ego_ref_trajectory(self, time_sequence: list):
221
+ """
222
+ Returns only one ego reference trajectory and not a batch because it is always the same.
223
+ Args:
224
+ time_sequence: the time points at which to get the positions
225
+ """
226
+ out = np.array([[[[t * self.ego_ref_speed, 0] for t in time_sequence]]])
227
+ if self._is_torch:
228
+ return torch.from_numpy(out.astype("float32"))
229
+ else:
230
+ return out
231
+
232
+ def get_pedestrians_velocities(self):
233
+ """
234
+ Returns the batch of mean pedestrian velocities between their positions and their final positions.
235
+ """
236
+ return (self.final_pedestrians_positions - self._pedestrians_positions)[
237
+ :, None
238
+ ] / self.time_scene
239
+
240
+ def get_ego_ref_velocity(self):
241
+ """
242
+ Returns the reference ego velocity.
243
+ """
244
+ if self._is_torch:
245
+ return torch.from_numpy(
246
+ np.array([[[[self.ego_ref_speed, 0]]]], dtype="float32")
247
+ )
248
+ else:
249
+ return np.array([[[[self.ego_ref_speed, 0]]]])
250
+
251
+ def get_ego_ref_position(self):
252
+ """
253
+ Returns the current reference ego position (at set time self.current_time)
254
+ """
255
+ if self._is_torch:
256
+ return torch.from_numpy(
257
+ np.array(
258
+ [[[[self.ego_ref_speed * self.current_time, 0]]]], dtype="float32"
259
+ )
260
+ )
261
+ else:
262
+ return np.array([[[[self.ego_ref_speed * self.current_time, 0]]]])
263
+
264
+ def set_current_time(self, time: float):
265
+ """
266
+ Set the current time of the scene.
267
+ Args:
268
+ time : The current time to set. It should be between 0 and self.time_scene
269
+ """
270
+ assert 0 <= time <= self.time_scene
271
+ self.current_time = time
272
+
273
+ def sample_ego_velocities(self, time_sequence: list):
274
+ """
275
+ Get ego velocity trajectories following the ego's acceleration distribution and the initial
276
+ velocity distribution.
277
+
278
+ Args:
279
+ time_sequence: a list of time points at which to sample the trajectory positions.
280
+ Returns:
281
+ batch of sequence of velocities of shape (batch_size, 1, len(time_sequence), 2)
282
+ """
283
+ vel_traj = []
284
+ # uniform sampling of acceleration_mean between self._ego_acceleration_mean_low and
285
+ # self._ego_acceleration_mean_high
286
+ acceleration_mean = np.random.rand(self._batch_size, 2) * np.array(
287
+ [
288
+ self._ego_acceleration_mean_high - self._ego_acceleration_mean_low,
289
+ 0.0,
290
+ ]
291
+ ) + np.array([self._ego_acceleration_mean_low, 0.0])
292
+ t_prev = 0
293
+ # uniform sampling of initial velocity between self._ego_speed_init_low and
294
+ # self._ego_speed_init_high
295
+ vel_prev = np.random.rand(self._batch_size, 2) * np.array(
296
+ [self._ego_speed_init_high - self._ego_speed_init_low, 0.0]
297
+ ) + np.array([self._ego_speed_init_low, 0.0])
298
+ for t in time_sequence:
299
+ # integrate accelerations once to get velocities
300
+ acceleration = acceleration_mean + np.random.randn(
301
+ self._batch_size, 2
302
+ ) * np.array([self._ego_acceleration_std, 0.0])
303
+ vel_prev = vel_prev + acceleration * (t - t_prev)
304
+ t_prev = t
305
+ vel_traj.append(vel_prev)
306
+ vel_traj = np.stack(vel_traj, 1)
307
+ if self._is_torch:
308
+ vel_traj = torch.from_numpy(vel_traj.astype("float32"))
309
+ return vel_traj[:, None]
310
+
311
+ def sample_ego_trajectories(self, time_sequence: list):
312
+ """
313
+ Get ego trajectories following the ego's acceleration distribution and the initial velocity
314
+ distribution.
315
+
316
+ Args:
317
+ time_sequence: a list of time points at which to sample the trajectory positions.
318
+ Returns:
319
+ batch of sequence of positions of shape (batch_size, len(time_sequence), 2)
320
+ """
321
+ vel_traj = self.sample_ego_velocities(time_sequence)
322
+ traj = []
323
+ t_prev = 0
324
+ pos_prev = np.array([[0, 0]], dtype="float32")
325
+ if self._is_torch:
326
+ pos_prev = torch.from_numpy(pos_prev)
327
+ for idx, t in enumerate(time_sequence):
328
+ # integrate velocities once to get positions
329
+ vel = vel_traj[:, :, idx, :]
330
+ pos_prev = pos_prev + vel * (t - t_prev)
331
+ t_prev = t
332
+ traj.append(pos_prev)
333
+ if self._is_torch:
334
+ return torch.stack(traj, -2)
335
+ else:
336
+ return np.stack(traj, -2)
337
+
338
+ def sample_pedestrians_trajectories(self, time_sequence: list):
339
+ """
340
+ Produce pedestrian trajectories following the pedestrian behavior distribution
341
+ (it is resampled, the final position will not match self.final_pedestrians_positions)
342
+ Args:
343
+ time_sequence: a list of time points at which to sample the trajectory positions.
344
+ Returns:
345
+ batch of sequence of positions of shape (batch_size, len(time_sequence), 2)
346
+ """
347
+ traj = []
348
+ t_prev = 0
349
+ pos_prev = self.pedestrians_positions
350
+ for t in time_sequence:
351
+ pos_prev = (
352
+ pos_prev
353
+ + self.pedestrians.get_final_position(t - t_prev)
354
+ - self.pedestrians.position
355
+ )
356
+ t_prev = t
357
+ traj.append(pos_prev)
358
+ if self._is_torch:
359
+ traj = torch.stack(traj, 2)
360
+ return traj + torch.randn_like(traj) * self.perception_noise_std
361
+ else:
362
+ traj = np.stack(traj, 2)
363
+ return traj + np.random.randn(*traj.shape) * self.perception_noise_std
364
+
365
+ def get_pedestrians_trajectories(self):
366
+ """
367
+ Returns the batch of pedestrian trajectories sampled every dt.
368
+ """
369
+ return self.pedestrians_trajectories
370
+
371
+ def get_pedestrian_trajectory(self, ind: int, time_sequence: list = None):
372
+ """
373
+ Returns one pedestrian trajectory of index ind sampled at times set in time_sequence.
374
+ Args:
375
+ ind: index of the pedestrian in the batch.
376
+ time_sequence: a list of time points at which to sample the trajectory positions.
377
+ Returns:
378
+ A pedestrian trajectory of shape (len(time_sequence), 2)
379
+ """
380
+ len_traj = len(self.sample_times)
381
+ if self._is_torch:
382
+ ped_traj = torch_linspace(
383
+ self.pedestrians_positions[ind],
384
+ self.final_pedestrians_positions[ind],
385
+ len_traj,
386
+ )
387
+ else:
388
+ ped_traj = np.linspace(
389
+ self.pedestrians_positions[ind],
390
+ self.final_pedestrians_positions[ind],
391
+ len_traj,
392
+ )
393
+
394
+ if time_sequence is not None:
395
+ n_steps = [int(t / self.dt) for t in time_sequence]
396
+ else:
397
+ n_steps = range(int(self.time_scene / self.dt))
398
+ return ped_traj[n_steps]
399
+
400
+
401
+ class SceneDataset(Dataset):
402
+ """
403
+ Dataset of scenes with one vehicle at constant velocity and one random pedestrian.
404
+ The scenes are randomly generated so the distribution can be sampled at each batch or pre-fetched.
405
+
406
+ Args:
407
+ len: int number of scenes per epoch
408
+ params: dataclass defining all the necessary parameters
409
+ pre_fetch: set to True to fetch the whole dataset at initialization
410
+ """
411
+
412
+ def __init__(
413
+ self,
414
+ len: int,
415
+ params: RandomSceneParams,
416
+ pre_fetch: bool = True,
417
+ ) -> None:
418
+ super().__init__()
419
+ self._pre_fetch = pre_fetch
420
+ self._len = len
421
+ self._sample_times = params.sample_times
422
+ self.params = copy.deepcopy(params)
423
+ params.batch_size = len
424
+ if self._pre_fetch:
425
+ self.scene_set = RandomScene(
426
+ params, is_torch=True
427
+ ).sample_pedestrians_trajectories(self._sample_times)
428
+
429
+ def __len__(self) -> int:
430
+ return self._len
431
+
432
+ # This is a hack, get item only returns the index so that the collate_fn can handle making the batch without looping on RandomScene creation.
433
+ def __getitem__(self, index: int) -> Tensor:
434
+ return index
435
+
436
+ def collate_fn(self, index_list: list) -> Tensor:
437
+ if self._pre_fetch:
438
+ return self.scene_set[torch.from_numpy(np.array(index_list))]
439
+ else:
440
+ self.params.batch_size = len(index_list)
441
+ return RandomScene(
442
+ self.params,
443
+ is_torch=True,
444
+ ).sample_pedestrians_trajectories(self._sample_times)
445
+
446
+
447
+ # Call this function to create a dataset as a .npy file that can be loaded as a numpy array with np.load(file_name.npy)
448
+ def save_dataset(file_path: str, size: int, config: Config):
449
+ """
450
+ Save a dataset at file_path using the configuration.
451
+ Args:
452
+ file_path: Where to save the dataset
453
+ size: Number of samples to save
454
+ config: Configuration to use for the dataset generation
455
+ """
456
+ dir_path = os.path.dirname(file_path)
457
+ config_path = os.path.join(dir_path, "config.py")
458
+ config = copy.deepcopy(config)
459
+ config.batch_size = size
460
+ params = RandomSceneParams.from_config(config)
461
+ scene = RandomScene(
462
+ params,
463
+ is_torch=False,
464
+ )
465
+ data_pedestrian = scene.sample_pedestrians_trajectories(config.sample_times)
466
+ data_ego = scene.sample_ego_trajectories(config.sample_times)
467
+ data = np.stack([data_pedestrian, data_ego], 0)
468
+ np.save(file_path, data)
469
+ # Cannot use config.dump here because it is buggy and does not work if config was not loaded from a file.
470
+ with open(config_path, "w", encoding="utf-8") as f:
471
+ f.write(config.pretty_text)
472
+
473
+
474
+ def load_create_dataset(
475
+ config: Config,
476
+ base_dir=None,
477
+ ) -> List:
478
+ """
479
+ Load the dataset described by its config if it exists or create one.
480
+
481
+ Args:
482
+ config: Configuration to use for the dataset
483
+ base_dir: Where to look for the dataset or to save it.
484
+ """
485
+
486
+ if base_dir is None:
487
+ base_dir = os.path.join(
488
+ os.path.dirname(os.path.realpath(__file__)), "..", "..", "data"
489
+ )
490
+ found = False
491
+ dataset_out = []
492
+ i = 0
493
+ dir_path = os.path.join(base_dir, f"scene_dataset_{i:03d}")
494
+ while os.path.exists(dir_path):
495
+ config_path = os.path.join(dir_path, "config.py")
496
+ if os.path.exists(config_path):
497
+ config_check = Config.fromfile(config_path)
498
+ if config_check.dataset_parameters == config.dataset_parameters:
499
+ found = True
500
+ break
501
+ else:
502
+ warnings.warn(
503
+ f"Dataset directory {dir_path} exists but doesn't contain a config file. Cannot use it."
504
+ )
505
+ i += 1
506
+ dir_path = os.path.join(base_dir, f"scene_dataset_{i:03d}")
507
+
508
+ if not found:
509
+ print(f"Dataset not found, creating a new one.")
510
+ os.makedirs(dir_path)
511
+ for dataset in config.datasets:
512
+ dataset_name = f"scene_dataset_{dataset}.npy"
513
+ dataset_path = os.path.join(dir_path, dataset_name)
514
+ save_dataset(dataset_path, config.datasets_sizes[dataset], config)
515
+ if found:
516
+ print(f"Loading existing dataset at {dir_path}.")
517
+
518
+ for dataset in config.datasets:
519
+ dataset_path = os.path.join(dir_path, f"scene_dataset_{dataset}.npy")
520
+ dataset_out.append(torch.from_numpy(np.load(dataset_path).astype("float32")))
521
+
522
+ return dataset_out
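A minimal usage sketch of RandomScene with an explicit RandomSceneParams (all values below are illustrative, not the ones from learning_config.py):

from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams

params = RandomSceneParams(
    batch_size=16,
    time_scene=5.0,
    sample_times=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
    ego_ref_speed=14.0,
    ego_speed_init_low=10.0,
    ego_speed_init_high=18.0,
    ego_acceleration_mean_low=-1.0,
    ego_acceleration_mean_high=1.0,
    ego_acceleration_std=0.5,
    ego_length=4.0,
    ego_width=1.75,
    dt=0.1,
    fast_speed=2.0,
    slow_speed=1.0,
    p_change_pace=0.1,
    proportion_fast=0.5,
    perception_noise_std=0.1,
)
scene = RandomScene(params, is_torch=False)
ped_trajs = scene.get_pedestrians_trajectories()                # (16, 1, 6, 2)
ego_trajs = scene.sample_ego_trajectories(params.sample_times)  # (16, 1, 6, 2)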
risk_biased/scene_dataset/scene_plotter.py ADDED
@@ -0,0 +1,276 @@
1
+ import os
2
+ from typing import Optional
3
+
4
+ from matplotlib.axes import Axes
5
+ from matplotlib.collections import PatchCollection
6
+ from matplotlib.lines import Line2D
7
+ from matplotlib.patches import Rectangle, Ellipse
8
+ import matplotlib
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+
12
+ from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams
13
+
14
+
15
+ class ScenePlotter:
16
+ """
17
+ This class defines plotting functions that takes in a scene and an optional axes to plot road agents and trajectories.
18
+
19
+ Args:
20
+ scene: The scene to use for plotting
21
+ ax: Matplotlib axes in which the drawing is made
22
+ """
23
+
24
+ def __init__(self, scene: RandomScene, ax: Optional[Axes] = None) -> None:
25
+ self.scene = scene
26
+ if ax is None:
27
+ self.ax = plt.subplot()
28
+ else:
29
+ self.ax = ax
30
+ self._sidewalks_boxes = PatchCollection(
31
+ [
32
+ Rectangle(
33
+ xy=[-scene.ego_length, scene.bottom],
34
+ height=scene.sidewalks_width,
35
+ width=scene.road_length + scene.ego_length,
36
+ ),
37
+ Rectangle(
38
+ xy=[-scene.ego_length, 3 * scene.lane_width / 2],
39
+ height=scene.sidewalks_width,
40
+ width=scene.road_length + scene.ego_length,
41
+ ),
42
+ ],
43
+ facecolor="gray",
44
+ alpha=0.3,
45
+ edgecolor="black",
46
+ )
47
+ self._center_line = Line2D(
48
+ [-scene.ego_length / 2, scene.road_length],
49
+ [scene.lane_width / 2, scene.lane_width / 2],
50
+ linewidth=4,
51
+ color="black",
52
+ dashes=[10, 5],
53
+ )
54
+
55
+ self._set_agent_patches()
56
+ self._set_agent_paths()
57
+ self.ax.set_aspect("equal")
58
+
59
+ def _set_current_time(self, time: float):
60
+ """
61
+ Set the current time to draw the agents at the proper time along the trajectory.
62
+
63
+ Args:
64
+ time: the present time in second
65
+ """
66
+ self.scene.set_current_time(time)
67
+ self._set_agent_patches()
68
+
69
+ def _set_agent_paths(self):
70
+ """
71
+ Defines path as lines.
72
+ """
73
+ self._ego_path = Line2D(
74
+ [0, self.scene.ego_ref_speed * self.scene.time_scene],
75
+ [0, 0],
76
+ linewidth=2,
77
+ color="red",
78
+ dashes=[4, 4],
79
+ alpha=0.3,
80
+ )
81
+
82
+ self._pedestrian_path = [
83
+ [
84
+ Line2D(
85
+ [init[agent, 0], final[agent, 0]],
86
+ [init[agent, 1], final[agent, 1]],
87
+ linewidth=2,
88
+ dashes=[4, 4],
89
+ alpha=0.3,
90
+ )
91
+ for (init, final) in zip(
92
+ self.scene.pedestrians_positions,
93
+ self.scene.final_pedestrians_positions,
94
+ )
95
+ ]
96
+ for agent in range(self.scene.pedestrians_positions.shape[1])
97
+ ]
98
+
99
+ def _set_agent_patches(self):
100
+ """
101
+ Set the agent patches at their current position in the scene.
102
+ """
103
+ current_step = int(round(self.scene.current_time / self.scene.dt))
104
+ self._ego_box = Rectangle(
105
+ xy=(
106
+ -self.scene.ego_length / 2
107
+ + self.scene.ego_ref_speed * self.scene.current_time,
108
+ -self.scene.ego_width / 2,
109
+ ),
110
+ height=self.scene.ego_width,
111
+ width=self.scene.ego_length,
112
+ fill=True,
113
+ facecolor="red",
114
+ alpha=0.4,
115
+ edgecolor="black",
116
+ )
117
+ self._pedestrian_patches = [
118
+ [
119
+ Ellipse(
120
+ xy=xy,
121
+ width=1,
122
+ height=0.5,
123
+ angle=angle * 180 / np.pi + 90,
124
+ facecolor="blue",
125
+ alpha=0.4,
126
+ edgecolor="black",
127
+ )
128
+ for xy, angle in zip(
129
+ self.scene.pedestrians_trajectories[:, agent, current_step],
130
+ self.scene.pedestrians.angle[:, agent],
131
+ )
132
+ ]
133
+ for agent in range(self.scene.pedestrians_trajectories.shape[1])
134
+ ]
135
+
136
+ def plot_road(self) -> None:
137
+ """
138
+ Plot the road as a two lanes, two sidewalks in straight lines with the ego vehicle. Plot is made in given ax.
139
+ """
140
+ self.ax.add_collection(self._sidewalks_boxes)
141
+ self.ax.add_patch(self._ego_box)
142
+ self.ax.add_line(self._center_line)
143
+ self.ax.add_line(self._ego_path)
144
+ self.rescale()
145
+
146
+ def draw_scene(self, index: int, time=None, prediction=None) -> None:
147
+ """
148
+ Plot the scene of given index (road, ego vehicle with its path, pedestrian with its path)
149
+ Args:
150
+ index: index of the pedestrian in the batch
151
+ time: set current time to this value if not None
152
+ prediction: draw this instead of the actual future if not None
153
+ """
154
+ if time is not None:
155
+ self._set_current_time(time)
156
+ self.plot_road()
157
+ for agent_patch in self._pedestrian_patches:
158
+ self.ax.add_patch(agent_patch[index])
159
+ for agent_patch in self._pedestrian_path:
160
+ self.ax.add_line(agent_patch[index])
161
+ if prediction is not None:
162
+ self.draw_trajectory(prediction)
163
+
164
+ def rescale(self):
165
+ """
166
+ Set the x and y limits to the road shape with a margin.
167
+ """
168
+ self.ax.set_xlim(
169
+ left=-2 * self.scene.ego_length,
170
+ right=self.scene.road_length + self.scene.ego_length,
171
+ )
172
+ self.ax.set_ylim(
173
+ bottom=self.scene.bottom - self.scene.lane_width,
174
+ top=2 * self.scene.lane_width + 2 * self.scene.sidewalks_width,
175
+ )
176
+
177
+ def draw_trajectory(self, prediction, color="b") -> None:
178
+ """
179
+ Plot the given prediction in the scene.
180
+ """
181
+ self.ax.scatter(prediction[..., 0], prediction[..., 1], color=color, alpha=0.3)
182
+
183
+ def draw_all_trajectories(
184
+ self,
185
+ prediction: np.ndarray,
186
+ color=None,
187
+ color_value: np.ndarray = None,
188
+ alpha: float = 0.05,
189
+ label: str = "trajectory",
190
+ ) -> None:
191
+ """
192
+ Plot all the given predictions in the scene
193
+ Args:
194
+ prediction : (batch, n_agents, time, 2) batch of trajectories
195
+ color: regular color name
196
+ color_value : (batch, n_agents) Optional batch of values for coloring from green to red
197
+ """
198
+
199
+ if color_value is not None:
200
+ min = color_value.min()
201
+ max = color_value.max()
202
+ color_value = 0.9 * (color_value - min) / (max - min)
203
+ for agent in range(prediction.shape[1]):
204
+ for traj, val in zip(prediction[:, agent], color_value[:, agent]):
205
+ color = (val, 1 - val, 0.1)
206
+ self.ax.plot(
207
+ traj[:, 0], traj[:, 1], color=color, alpha=alpha, label=label
208
+ )
209
+ self.ax.scatter(traj[-1, 0], traj[-1, 1], color=color, alpha=alpha)
210
+ cmap = matplotlib.colors.ListedColormap(
211
+ np.linspace(
212
+ [color_value.min(), 1 - color_value.min(), 0.1],
213
+ [color_value.max(), 1 - color_value.max(), 0.1],
214
+ 128,
215
+ )
216
+ )
217
+ norm = matplotlib.colors.Normalize(vmin=min, vmax=max, clip=True)
218
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
219
+ plt.colorbar(sm, label="TTC cost")
220
+ else:
221
+ for agent in range(prediction.shape[1]):
222
+ for traj in prediction:
223
+ self.ax.plot(
224
+ traj[agent, :, 0],
225
+ traj[agent, :, 1],
226
+ color=color,
227
+ alpha=alpha,
228
+ label=label,
229
+ )
230
+ self.ax.scatter(
231
+ prediction[:, agent, -1, 0],
232
+ prediction[:, agent, -1, 1],
233
+ color=color,
234
+ alpha=alpha,
235
+ )
236
+
237
+ def draw_legend(self):
238
+ """Draw legend without repeats and without transparency."""
239
+
240
+ handles, labels = self.ax.get_legend_handles_labels()
241
+ i = np.arange(len(labels))
242
+ filter = np.array([])
243
+ unique_labels = list(set(labels))
244
+ for ul in unique_labels:
245
+ filter = np.append(filter, [i[np.array(labels) == ul][0]])
246
+ filtered_handles = []
247
+ for f in filter:
248
+ handles[int(f)].set_alpha(1)
249
+ filtered_handles.append(handles[int(f)])
250
+ filtered_labels = [labels[int(f)] for f in filter]
251
+ self.ax.legend(filtered_handles, filtered_labels)
252
+
253
+
254
+ # Draw a random scene
255
+ if __name__ == "__main__":
256
+ from risk_biased.utils.config_argparse import config_argparse
257
+
258
+ working_dir = os.path.dirname(os.path.realpath(__file__))
259
+ config_path = os.path.join(
260
+ working_dir, "..", "..", "risk_biased", "config", "learning_config.py"
261
+ )
262
+ config = config_argparse(config_path)
263
+ n_samples = 100
264
+
265
+ scene_params = RandomSceneParams.from_config(config)
266
+ scene_params.batch_size = n_samples
267
+ scene = RandomScene(
268
+ scene_params,
269
+ is_torch=False,
270
+ )
271
+
272
+ plotter = ScenePlotter(scene)
273
+
274
+ plotter.draw_scene(0, time=1)
275
+ plt.tight_layout()
276
+ plt.show()
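Continuing the __main__ example above (before plt.show() is called), model predictions could be overlaid on the scene; the array below is a random stand-in with the (batch, n_agents, time, 2) layout expected by draw_all_trajectories:

# Random stand-in for predicted futures around the pedestrian start positions.
fake_pred = (
    np.random.randn(n_samples, 1, 10, 2) * 0.5
    + scene.pedestrians_positions[:, :, None, :]
)
plotter.draw_all_trajectories(fake_pred, color="orange", alpha=0.1, label="prediction")
plotter.draw_legend()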
risk_biased/utils/__init__.py ADDED
File without changes
risk_biased/utils/callbacks.py ADDED
@@ -0,0 +1,595 @@
1
+ import copy
2
+ from dataclasses import dataclass
3
+
4
+ from mmcv import Config
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
8
+ import pytorch_lightning as pl
9
+ import torch
10
+ import wandb
11
+
12
+ from risk_biased.scene_dataset.loaders import SceneDataLoaders
13
+ from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams
14
+ from risk_biased.scene_dataset.scene_plotter import ScenePlotter
15
+ from risk_biased.utils.cost import (
16
+ DistanceCostNumpy,
17
+ DistanceCostParams,
18
+ TTCCostNumpy,
19
+ TTCCostParams,
20
+ )
21
+ from risk_biased.utils.risk import get_risk_level_sampler
22
+
23
+
24
+ class SwitchTrainingModeCallback(pl.Callback):
25
+ """
26
+ This callback switches between CVAE training and biasing training for the biased_latent_cvae_model
27
+ Args:
28
+ switch_at_epoch: The number of epochs after which to make the switch. The CVAE is no longer trained after that point.
29
+ """
30
+
31
+ def __init__(self, switch_at_epoch: int) -> None:
32
+ super().__init__()
33
+ self._switch_at_epoch = switch_at_epoch
34
+ self._train_has_started = False
35
+
36
+ def on_train_start(
37
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
38
+ ) -> None:
39
+ """Store the optimizer list and set the trainer to the first optimizer."""
40
+ self._optimizers = trainer.optimizers
41
+ trainer.optimizers = [self._optimizers[0]]
42
+ self._train_has_started = True
43
+
44
+ def on_epoch_start(
45
+ self, trainer: pl.Trainer, pl_module: pl.LightningModule
46
+ ) -> None:
47
+ """
48
+ Check if the switch should be made and if so,
49
+ set the trainer on the second optimizer.
50
+ """
51
+ if trainer.current_epoch == self._switch_at_epoch and self._train_has_started:
52
+ print("Switching to bias training.")
53
+ pl_module.set_training_mode("bias")
54
+ trainer.optimizers = [self._optimizers[1]]
55
+
56
+
57
+ def get_fast_slow_scenes(params: RandomSceneParams, n_samples: int):
58
+ """Define and return two RandomScene objects, one initialized such that slow
59
+ pedestrians are safer and the other such that fast pedestrians are safer.
60
+
61
+ Args:
62
+ params: dataclass containing the necessary parameters for a RandomScene object
63
+ n_samples: number of samples to draw in each scene
64
+ """
65
+ params = copy.deepcopy(params)
66
+ params.batch_size = n_samples
67
+ scene_safe_slow = RandomScene(
68
+ params,
69
+ is_torch=False,
70
+ )
71
+ percent_right = 0.8
72
+ percent_top = 1.1
73
+ angle = 5 * np.pi / 4
74
+ positions = np.array([[[percent_right, percent_top]]] * n_samples)
75
+ angles = np.array([[angle]] * n_samples)
76
+ scene_safe_slow.set_pedestrians_states(positions, angles)
77
+
78
+ scene_safe_fast = RandomScene(
79
+ params,
80
+ is_torch=False,
81
+ )
82
+ percent_right = 0.8
83
+ percent_top = 0.6
84
+ angle = 5 * np.pi / 4
85
+ positions = np.array([[[percent_right, percent_top]]] * n_samples)
86
+ angles = np.array([[angle]] * n_samples)
87
+ scene_safe_fast.set_pedestrians_states(positions, angles)
88
+ return scene_safe_fast, scene_safe_slow
89
+
90
+
91
+ @dataclass
92
+ class DrawCallbackParams:
93
+ """
94
+ Args:
95
+ scene_params: dataclass parameters for the RandomScene
96
+ dist_cost_params: dataclass parameters for the DistanceCost
97
+ ttc_cost_params: dataclass parameters for the TTCCost
98
+ plot_interval_epoch: number of epochs between each plot drawing
99
+ histogram_interval_epoch: number of epochs between each histogram drawing
100
+ num_steps: number of time steps as defined in the config
101
+ num_steps_future: number of time steps in the future as defined in the config
102
+ risk_distribution: dict object describing a risk distribution
103
+ dt: time step size as defined in the config
104
+ """
105
+
106
+ scene_params: RandomSceneParams
107
+ dist_cost_params: DistanceCostParams
108
+ ttc_cost_params: TTCCostParams
109
+ plot_interval_epoch: int
110
+ histogram_interval_epoch: int
111
+ num_steps: int
112
+ num_steps_future: int
113
+ risk_distribution: dict
114
+ dt: float
115
+
116
+ @staticmethod
117
+ def from_config(cfg: Config):
118
+ return DrawCallbackParams(
119
+ scene_params=RandomSceneParams.from_config(cfg),
120
+ dist_cost_params=DistanceCostParams.from_config(cfg),
121
+ ttc_cost_params=TTCCostParams.from_config(cfg),
122
+ plot_interval_epoch=cfg.plot_interval_epoch,
123
+ histogram_interval_epoch=cfg.histogram_interval_epoch,
124
+ num_steps=cfg.num_steps,
125
+ num_steps_future=cfg.num_steps_future,
126
+ risk_distribution=cfg.risk_distribution,
127
+ dt=cfg.dt,
128
+ )
129
+
130
+
131
+ class HistogramCallback(pl.Callback):
132
+ """Logs histograms of distances, distance cost and ttc cost for the data, the predictions at risk_level=0, the predictions at risk_level=1
133
+ Args:
134
+ params: dataclass defining the necessary parameters
135
+ n_samples: Number of samples to use for the histogram plot
136
+ """
137
+
138
+ def __init__(
139
+ self,
140
+ params: DrawCallbackParams,
141
+ n_samples=1000,
142
+ ):
143
+ super().__init__()
144
+ self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes(
145
+ params.scene_params, n_samples
146
+ )
147
+ self.num_steps = params.num_steps
148
+ self.n_scenes = n_samples
149
+ self.sample_times = params.scene_params.sample_times
150
+ self.dist_cost_func = DistanceCostNumpy(params.dist_cost_params)
151
+ self.ttc_cost_func = TTCCostNumpy(params.ttc_cost_params)
152
+ self.histogram_interval_epoch = params.histogram_interval_epoch
153
+
154
+ self.ego_traj = self.scene_safe_fast.get_ego_ref_trajectory(self.sample_times)
155
+
156
+ self._risk_sampler = get_risk_level_sampler(params.risk_distribution)
157
+
158
+ def _log_scene(self, pl_module: pl.LightningModule, scene: RandomScene, name: str):
159
+ """
160
+ Log in WandB three histogram for the given scene: One for the data, one for the predictions at risk_level=0 and one for the predictions at risk_level=1
161
+ Args:
162
+ pl_module: LightningModule object
163
+ scene: RandomScene object
164
+ name: name of the given scene
165
+ """
166
+ ped_trajs = scene.get_pedestrians_trajectories()
167
+ device = pl_module.device
168
+ n_agents = ped_trajs.shape[1]
169
+
170
+ input_traj = ped_trajs[..., : self.num_steps, :]
171
+
172
+ normalized_input, offset = SceneDataLoaders.normalize_trajectory(
173
+ torch.from_numpy(input_traj.astype("float32")).contiguous().to(device)
174
+ )
175
+ mask_input = torch.ones_like(normalized_input[..., 0])
176
+ ego_history = (
177
+ torch.from_numpy(self.ego_traj[..., : self.num_steps, :].astype("float32"))
178
+ .expand_as(normalized_input)
179
+ .contiguous()
180
+ .to(device)
181
+ )
182
+ ego_future = (
183
+ torch.from_numpy(self.ego_traj[..., self.num_steps :, :].astype("float32"))
184
+ .expand(normalized_input.shape[0], n_agents, -1, -1)
185
+ .contiguous()
186
+ .to(device)
187
+ )
188
+ map = torch.empty(ego_history.shape[0], 0, 0, 2, device=mask_input.device)
189
+ mask_map = torch.empty(ego_history.shape[0], 0, 0, device=mask_input.device)
190
+
191
+ pred_riskier = (
192
+ pl_module.predict_step(
193
+ (
194
+ normalized_input,
195
+ mask_input,
196
+ map,
197
+ mask_map,
198
+ offset,
199
+ ego_history,
200
+ ego_future,
201
+ ),
202
+ 0,
203
+ risk_level=self._risk_sampler.get_highest_risk(
204
+ batch_size=self.n_scenes, device=device
205
+ )
206
+ .unsqueeze(1)
207
+ .repeat(1, n_agents),
208
+ )
209
+ .cpu()
210
+ .detach()
211
+ .numpy()
212
+ )
213
+
214
+ pred = (
215
+ pl_module.predict_step(
216
+ (
217
+ normalized_input,
218
+ mask_input,
219
+ map,
220
+ mask_map,
221
+ offset,
222
+ ego_history,
223
+ ego_future,
224
+ ),
225
+ 0,
226
+ risk_level=None,
227
+ )
228
+ .cpu()
229
+ .detach()
230
+ .numpy()
231
+ )
232
+
233
+ ped_trajs_pred = np.concatenate((input_traj, pred), axis=-2)
234
+ ped_trajs_pred_riskier = np.concatenate((input_traj, pred_riskier), axis=-2)
235
+
236
+ travel_distances = np.sqrt(
237
+ np.square(ped_trajs[..., -1, :] - ped_trajs[..., 0, :]).sum(-1)
238
+ )
239
+
240
+ dist_cost, dist = self.dist_cost_func(
241
+ self.ego_traj[..., self.num_steps :, :],
242
+ ped_trajs[..., self.num_steps :, :],
243
+ )
244
+
245
+ ttc_cost, (ttc, dist) = self.ttc_cost_func(
246
+ self.ego_traj[..., self.num_steps :, :],
247
+ ped_trajs[..., self.num_steps :, :],
248
+ scene.get_ego_ref_velocity(),
249
+ scene.get_pedestrians_velocities(),
250
+ )
251
+
252
+ travel_distances_pred = np.sqrt(
253
+ np.square(ped_trajs_pred[..., -1, :] - ped_trajs_pred[..., 0, :]).sum(-1)
254
+ )
255
+ dist_cost_pred, dist_pred = self.dist_cost_func(
256
+ self.ego_traj[..., self.num_steps :, :],
257
+ ped_trajs_pred[..., self.num_steps :, :],
258
+ )
259
+ sample_times = np.array(self.sample_times)
260
+ ped_velocities_pred = (
261
+ ped_trajs_pred[..., 1:, :] - ped_trajs_pred[..., :-1, :]
262
+ ) / ((sample_times[1:] - sample_times[:-1])[None, None, :, None])
263
+ ped_velocities_pred = np.concatenate(
264
+ (ped_velocities_pred[..., 0:1, :], ped_velocities_pred), -2
265
+ )
266
+ ttc_cost_pred, (ttc_pred, dist_pred) = self.ttc_cost_func(
267
+ self.ego_traj[..., self.num_steps :, :],
268
+ ped_trajs_pred[..., self.num_steps :, :],
269
+ scene.get_ego_ref_velocity(),
270
+ ped_velocities_pred[..., self.num_steps :, :],
271
+ )
272
+
273
+ travel_distances_pred_riskier = np.sqrt(
274
+ np.square(
275
+ ped_trajs_pred_riskier[..., -1, :] - ped_trajs_pred_riskier[..., 0, :]
276
+ ).sum(-1)
277
+ )
278
+
279
+ dist_cost_pred_riskier, dist_pred_riskier = self.dist_cost_func(
280
+ self.ego_traj[..., self.num_steps :, :],
281
+ ped_trajs_pred_riskier[..., self.num_steps :, :],
282
+ )
283
+ sample_times = np.array(self.sample_times)
284
+ ped_velocities_pred_riskier = (
285
+ ped_trajs_pred_riskier[..., 1:, :] - ped_trajs_pred_riskier[..., :-1, :]
286
+ ) / ((sample_times[1:] - sample_times[:-1])[None, None, :, None])
287
+ ped_velocities_pred_riskier = np.concatenate(
288
+ (ped_velocities_pred_riskier[..., 0:1, :], ped_velocities_pred_riskier), -2
289
+ )
290
+ ttc_cost_pred_riskier, (ttc_pred, dist_pred_riskier) = self.ttc_cost_func(
291
+ self.ego_traj[..., self.num_steps :, :],
292
+ ped_trajs_pred_riskier[..., self.num_steps :, :],
293
+ scene.get_ego_ref_velocity(),
294
+ ped_velocities_pred_riskier[..., self.num_steps :, :],
295
+ )
296
+ data = [
297
+ [dist, dist_pred, dist_risk]
298
+ for (dist, dist_pred, dist_risk) in zip(
299
+ travel_distances.flatten(),
300
+ travel_distances_pred.flatten(),
301
+ travel_distances_pred_riskier.flatten(),
302
+ )
303
+ ]
304
+ table_travel_distance = wandb.Table(
305
+ data=data,
306
+ columns=[
307
+ "Travel distance data " + name,
308
+ "Travel distance prediction " + name,
309
+ "Travel distance riskier " + name,
310
+ ],
311
+ )
312
+ data = [
313
+ [cost, cost_pred, cost_risk]
314
+ for (cost, cost_pred, cost_risk) in zip(
315
+ dist_cost.flatten(),
316
+ dist_cost_pred.flatten(),
317
+ dist_cost_pred_riskier.flatten(),
318
+ )
319
+ ]
320
+ table_distance_cost = wandb.Table(
321
+ data=data,
322
+ columns=[
323
+ "Distance cost data " + name,
324
+ "Distance cost prediction " + name,
325
+ "Distance cost riskier " + name,
326
+ ],
327
+ )
328
+ data = [
329
+ [ttc, ttc_pred, ttc_risk]
330
+ for (ttc, ttc_pred, ttc_risk) in zip(
331
+ ttc_cost.flatten(),
332
+ ttc_cost_pred.flatten(),
333
+ ttc_cost_pred_riskier.flatten(),
334
+ )
335
+ ]
336
+ table_ttc_cost = wandb.Table(
337
+ data=data,
338
+ columns=[
339
+ "TTC cost data " + name,
340
+ "TTC cost prediction " + name,
341
+ "TTC cost riskier " + name,
342
+ ],
343
+ )
344
+ wandb.log(
345
+ {
346
+ "Travel distance data "
347
+ + name: wandb.plot_table(
348
+ vega_spec_name="jmercat/histogram_01_bins",
349
+ data_table=table_travel_distance,
350
+ fields={
351
+ "value": "Travel distance data " + name,
352
+ "title": "Travel distance data " + name,
353
+ },
354
+ ),
355
+ "Travel distance prediction "
356
+ + name: wandb.plot_table(
357
+ vega_spec_name="jmercat/histogram_01_bins",
358
+ data_table=table_travel_distance,
359
+ fields={
360
+ "value": "Travel distance prediction " + name,
361
+ "title": "Travel distance prediction " + name,
362
+ },
363
+ ),
364
+ "Travel distance riskier "
365
+ + name: wandb.plot_table(
366
+ vega_spec_name="jmercat/histogram_01_bins",
367
+ data_table=table_travel_distance,
368
+ fields={
369
+ "value": "Travel distance riskier " + name,
370
+ "title": "Travel distance riskier " + name,
371
+ },
372
+ ),
373
+ "Distance cost data "
374
+ + name: wandb.plot_table(
375
+ vega_spec_name="jmercat/histogram_0025_bins",
376
+ data_table=table_distance_cost,
377
+ fields={
378
+ "value": "Distance cost data " + name,
379
+ "title": "Distance cost data " + name,
380
+ },
381
+ ),
382
+ "Distance cost prediction "
383
+ + name: wandb.plot_table(
384
+ vega_spec_name="jmercat/histogram_0025_bins",
385
+ data_table=table_distance_cost,
386
+ fields={
387
+ "value": "Distance cost prediction " + name,
388
+ "title": "Distance cost prediction " + name,
389
+ },
390
+ ),
391
+ "Distance cost riskier "
392
+ + name: wandb.plot_table(
393
+ vega_spec_name="jmercat/histogram_0025_bins",
394
+ data_table=table_distance_cost,
395
+ fields={
396
+ "value": "Distance cost riskier " + name,
397
+ "title": "Distance cost riskier " + name,
398
+ },
399
+ ),
400
+ "TTC cost data "
401
+ + name: wandb.plot_table(
402
+ vega_spec_name="jmercat/histogram_005_bins",
403
+ data_table=table_ttc_cost,
404
+ fields={
405
+ "value": "TTC cost data " + name,
406
+ "title": "TTC cost data " + name,
407
+ },
408
+ ),
409
+ "TTC cost prediction "
410
+ + name: wandb.plot_table(
411
+ vega_spec_name="jmercat/histogram_005_bins",
412
+ data_table=table_ttc_cost,
413
+ fields={
414
+ "value": "TTC cost prediction " + name,
415
+ "title": "TTC cost prediction " + name,
416
+ },
417
+ ),
418
+ "TTC cost riskier "
419
+ + name: wandb.plot_table(
420
+ vega_spec_name="jmercat/histogram_005_bins",
421
+ data_table=table_ttc_cost,
422
+ fields={
423
+ "value": "TTC cost riskier " + name,
424
+ "title": "TTC cost riskier " + name,
425
+ },
426
+ ),
427
+ }
428
+ )
429
+
430
+ def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
431
+ """After a validation at the end of every histogram_interval_epoch,
432
+ log the histograms for two scenes: the safer fast scene and the safer slow scene.
433
+ """
434
+ if (
435
+ trainer.current_epoch % self.histogram_interval_epoch
436
+ == self.histogram_interval_epoch - 1
437
+ ):
438
+ self._log_scene(pl_module, self.scene_safe_fast, name="Safer fast")
439
+ self._log_scene(pl_module, self.scene_safe_slow, name="Safer slow")
440
+
441
+
442
+ class PlotTrajCallback(pl.Callback):
443
+ """Plot trajectory samples for two scenes:
444
+ One that is safer for the slow pedestrians
445
+ One that is safer for the fast pedestrians
446
+ Samples of ground truth, prediction, and biased predictions are superposed.
447
+ Last positions are marked to visualize the clusters.
448
+
449
+ Args:
450
+ params: dataclass containing the necessary parameters
451
+ n_samples: number of sample trajectories to draw
452
+ """
453
+
454
+ def __init__(
455
+ self,
456
+ params: DrawCallbackParams,
457
+ n_samples: int = 1,
458
+ ):
459
+ super().__init__()
460
+ self.n_samples = n_samples
461
+ self.num_steps = params.num_steps
462
+ self.dt = params.scene_params.dt
463
+ self.scene_params = params.scene_params
464
+ self.plot_interval_epoch = params.plot_interval_epoch
465
+ self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes(
466
+ params.scene_params, n_samples
467
+ )
468
+ self.ego_traj = self.scene_safe_fast.get_ego_ref_trajectory(
469
+ params.scene_params.sample_times
470
+ )
471
+ self._risk_sampler = get_risk_level_sampler(params.risk_distribution)
472
+
473
+ def _log_scene(self, epoch: int, pl_module, scene: RandomScene, name: str) -> None:
474
+ """Add drawing of samples of prediction, biased prediction and ground truth in the scene.
475
+
476
+ Args:
477
+ epoch: current epoch calling the log
478
+ pl_module: pytorch lightning module being trained
479
+ scene: scene to draw
480
+ name: name of the scene
481
+ """
482
+ ped_trajs = scene.get_pedestrians_trajectories()
483
+ device = pl_module.device
484
+ n_agents = ped_trajs.shape[1]
485
+
486
+ input_traj = ped_trajs[..., : self.num_steps, :]
487
+
488
+ normalized_input, offset = SceneDataLoaders.normalize_trajectory(
489
+ torch.from_numpy(input_traj.astype("float32")).contiguous().to(device)
490
+ )
491
+ mask_input = torch.ones_like(normalized_input[..., 0])
492
+ ego_history = (
493
+ torch.from_numpy(self.ego_traj[..., : self.num_steps, :].astype("float32"))
494
+ .expand_as(normalized_input)
495
+ .contiguous()
496
+ .to(device)
497
+ )
498
+ ego_future = (
499
+ torch.from_numpy(self.ego_traj[..., self.num_steps :, :].astype("float32"))
500
+ .expand(normalized_input.shape[0], n_agents, -1, -1)
501
+ .contiguous()
502
+ .to(device)
503
+ )
504
+ map = torch.empty(ego_history.shape[0], 0, 0, 2, device=mask_input.device)
505
+ mask_map = torch.empty(ego_history.shape[0], 0, 0, device=mask_input.device)
506
+
507
+ pred_riskier = (
508
+ pl_module.predict_step(
509
+ (
510
+ normalized_input,
511
+ mask_input,
512
+ map,
513
+ mask_map,
514
+ offset,
515
+ ego_history,
516
+ ego_future,
517
+ ),
518
+ 0,
519
+ risk_level=self._risk_sampler.get_highest_risk(
520
+ batch_size=self.n_samples, device=device
521
+ )
522
+ .unsqueeze(1)
523
+ .repeat(1, n_agents),
524
+ )
525
+ .cpu()
526
+ .detach()
527
+ .numpy()
528
+ )
529
+
530
+ pred = (
531
+ pl_module.predict_step(
532
+ (
533
+ normalized_input,
534
+ mask_input,
535
+ map,
536
+ mask_map,
537
+ offset,
538
+ ego_history,
539
+ ego_future,
540
+ ),
541
+ 0,
542
+ risk_level=None,
543
+ )
544
+ .cpu()
545
+ .detach()
546
+ .numpy()
547
+ )
548
+
549
+ fig, ax = plt.subplots()
550
+ plotter = ScenePlotter(scene, ax=ax)
551
+ fig.set_size_inches(h=scene.road_width / 3 + 1, w=scene.road_length / 3)
552
+
553
+ time = self.dt * self.num_steps
554
+ plotter.draw_scene(0, time=time)
555
+ alpha = 0.5 / np.log(self.n_samples)
556
+ plotter.draw_all_trajectories(
557
+ ped_trajs[..., self.num_steps :, :],
558
+ color="g",
559
+ alpha=alpha,
560
+ label="Future ground truth",
561
+ )
562
+ plotter.draw_all_trajectories(
563
+ input_traj, color="b", alpha=alpha, label="Past input"
564
+ )
565
+ plotter.draw_all_trajectories(
566
+ pred, color="orange", alpha=alpha, label="Prediction"
567
+ )
568
+ plotter.draw_all_trajectories(
569
+ pred_riskier, color="r", alpha=alpha, label="Prediction risk-seeking"
570
+ )
571
+ plotter.draw_legend()
572
+ plt.tight_layout()
573
+ wandb.log({"Road scene " + name: wandb.Image(fig), "epoch": epoch})
574
+ plt.close()
575
+
576
+ def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
577
+ """After a validation at the end of every plot_interval_epoch,
578
+ log the prediction samples for two scenes: the safer fast scene and the safer slow scene.
579
+ """
580
+ if (
581
+ trainer.current_epoch % self.plot_interval_epoch
582
+ == self.plot_interval_epoch - 1
583
+ ):
584
+ self.scene_safe_fast, self.scene_safe_slow = get_fast_slow_scenes(
585
+ self.scene_params, self.n_samples
586
+ )
587
+ self._log_scene(
588
+ trainer.current_epoch, pl_module, self.scene_safe_slow, "Safer slow"
589
+ )
590
+ self._log_scene(
591
+ trainer.current_epoch, pl_module, self.scene_safe_fast, "Safer fast"
592
+ )
593
+
594
+
595
+ # TODO: make the same kind of logs for the Waymo dataset
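A sketch of how these callbacks might be attached to a PyTorch Lightning trainer; cfg is an mmcv Config (e.g. from config_argparse), and model, train_loader, and val_loader are assumed to be built elsewhere in the repository. The epoch counts are illustrative values, not config fields:

import pytorch_lightning as pl

params = DrawCallbackParams.from_config(cfg)
callbacks = [
    SwitchTrainingModeCallback(switch_at_epoch=50),  # illustrative value
    HistogramCallback(params, n_samples=1000),
    PlotTrajCallback(params, n_samples=50),
]
trainer = pl.Trainer(max_epochs=100, callbacks=callbacks)
trainer.fit(model, train_loader, val_loader)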
risk_biased/utils/config_argparse.py ADDED
@@ -0,0 +1,96 @@
1
+ import os
2
+ from typing import Optional, Union, List
3
+ import warnings
4
+
5
+ import argparse
6
+ from mmcv import Config
7
+
8
+
9
+ def config_argparse(config_path: Optional[Union[str, List[str]]] = None) -> Config:
10
+ """Function that loads the config file as an MMCV Config object and overwrites its values with argparsed arguments.
11
+
12
+ Args:
13
+ config_path : path of the mmcv config file
14
+
15
+ Returns:
16
+ MMCV Config object with default values from the config_path and overwritten values from argparse
17
+ """
18
+ if config_path is None:
19
+ working_dir = os.path.dirname(os.path.realpath(__file__))
20
+ config_path = os.path.join(
21
+ working_dir, "..", "..", "config", "learning_config.py"
22
+ )
23
+ if isinstance(config_path, str):
24
+ cfg = Config.fromfile(config_path)
25
+ else:
26
+ cfg = Config.fromfile(config_path[0])
27
+ for path in config_path[1:]:
28
+ c = Config.fromfile(path)
29
+ cfg.update(c)
30
+
31
+ parser = argparse.ArgumentParser()
32
+ excluded_args = ["force_config", "load_last"]
33
+ overwritable_types = (str, float, int, list)
34
+ for key, value in cfg.items():
35
+ if key not in excluded_args:
36
+ if list in overwritable_types and isinstance(value, list):
37
+ if len(value) > 0:
38
+ parser.add_argument(
39
+ "--" + key, default=value, nargs="+", type=type(value[0])
40
+ )
41
+ else:
42
+ parser.add_argument("--" + key, default=value, nargs="+")
43
+ elif isinstance(value, overwritable_types):
44
+ parser.add_argument("--" + key, default=value, type=type(value))
45
+
46
+ if "load_from" not in cfg.keys():
47
+ parser.add_argument(
48
+ "--load_from",
49
+ default="",
50
+ type=str,
51
+ help="""Use this to load the model weights from a wandb checkpoint,
52
+ refer to the checkpoint with the wandb id (example:'1f1ho81a')""",
53
+ )
54
+
55
+ parser.add_argument(
56
+ "--load_last",
57
+ action="store_true",
58
+ help="""Use this flag to force the use of the last checkpoint instead of the best one
59
+ when loading a model.""",
60
+ )
61
+
62
+ parser.add_argument(
63
+ "--force_config",
64
+ action="store_true",
65
+ help="""Use this flag to force the use of the local config file
66
+ when loading a model from a checkpoint. Otherwise the checkpoint config file is used.
67
+ In any case the parameters can be overwritten with an argparse argument.""",
68
+ )
69
+ if "force_config" not in cfg.keys():
70
+ parser.set_defaults(force_config=False)
71
+ else:
72
+ parser.set_defaults(force_config=cfg.force_config)
73
+
74
+ if "load_last" not in cfg.keys():
75
+ parser.set_defaults(load_last=False)
76
+ else:
77
+ parser.set_defaults(load_last=cfg.load_last)
78
+
79
+ args = parser.parse_args()
80
+
81
+ # Print a warning in case the parameter 'dt' or 'time_scene' is changed because 'sample_times' might need to be updated accordingly.
82
+ if (
83
+ args.dt != cfg.dt or args.time_scene != cfg.time_scene
84
+ ) and args.sample_times == cfg.sample_times:
85
+ warnings.warn(
86
+ f"""Parameter 'dt' has been changed from {cfg.dt} to {args.dt} by
87
+ a command line argument, it might be used to set the parameter 'sample_times' that
88
+ cannot be updated accordingly. Consider setting 'dt' in {config_path} instead."""
89
+ )
90
+ # Config has a dataset_parameters field that copies the parameters related to dataset to compare them,
91
+ # they must be updated too if some of the dataset parameters were changed by argparse
92
+ for key, value in cfg.dataset_parameters.items():
93
+ if isinstance(value, overwritable_types):
94
+ cfg.dataset_parameters[key] = args.__getattribute__(key)
95
+ cfg.update(args.__dict__)
96
+ return cfg
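For reference, a minimal sketch of how this helper might be invoked from a training script; the script name and the overridden flag are illustrative assumptions, not values defined by this commit:

    # Hypothetical usage sketch: load the default learning config and let
    # command-line flags override its scalar fields.
    from risk_biased.utils.config_argparse import config_argparse

    if __name__ == "__main__":
        # With no argument, the helper falls back to config/learning_config.py
        # resolved relative to this module.
        cfg = config_argparse()
        # Scalar config entries can then be overridden from the command line, e.g.:
        #   python my_train_script.py --batch_size 64
        print(cfg.batch_size)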
risk_biased/utils/cost.py ADDED
@@ -0,0 +1,539 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Callable, Optional, Tuple
3
+
4
+ from mmcv import Config
5
+ import numpy as np
6
+ import torch
7
+ from torch import Tensor
8
+
9
+
10
+ def masked_min_torch(x, mask=None, dim=None):
11
+ if mask is not None:
12
+ x = torch.masked_fill(x, torch.logical_not(mask), float("inf"))
13
+ if dim is None:
14
+ return torch.min(x)
15
+ else:
16
+ return torch.min(x, dim=dim)[0]
17
+
18
+
19
+ def masked_max_torch(x, mask=None, dim=None):
20
+ if mask is not None:
21
+ x = torch.masked_fill(x, torch.logical_not(mask), float("-inf"))
22
+ if dim is None:
23
+ return torch.max(x)
24
+ else:
25
+ return torch.max(x, dim=dim)[0]
26
+
27
+
28
+ def get_masked_discounted_mean_torch(discount_factor=0.95):
29
+ def masked_discounted_mean_torch(x, mask=None, dim=None):
30
+ discount_tensor = torch.full(x.shape, discount_factor, device=x.device)
31
+ discount_tensor = torch.cumprod(discount_tensor, dim=-2)
32
+ if mask is not None:
33
+ x = torch.masked_fill(x, torch.logical_not(mask), 0)
34
+ if dim is None:
35
+ assert mask.any()
36
+ return (x * discount_tensor).sum() / (mask * discount_tensor).sum()
37
+ else:
38
+ return (x * discount_tensor).sum(dim) / (mask * discount_tensor).sum(
39
+ dim
40
+ ).clamp_min(1)
41
+ else:
42
+ if dim is None:
43
+ return (x * discount_tensor).sum() / discount_tensor.sum()
44
+ else:
45
+ return (x * discount_tensor).sum(dim) / discount_tensor.sum(dim)
46
+
47
+ return masked_discounted_mean_torch
48
+
49
+
50
+ def masked_mean_torch(x, mask=None, dim=None):
51
+ if mask is not None:
52
+ x = torch.masked_fill(x, torch.logical_not(mask), 0)
53
+ if dim is None:
54
+ assert mask.any()
55
+ return x.sum() / mask.sum()
56
+ else:
57
+ return x.sum(dim) / mask.sum(dim).clamp_min(1)
58
+ else:
59
+ if dim is None:
60
+ return x.mean()
61
+ else:
62
+ return x.mean(dim)
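As a quick illustration of the masked reductions defined above (the tensors are made up for this note, not taken from the repository):

    import torch

    x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    mask = torch.tensor([[True, True, False], [True, False, False]])

    # Masked entries are ignored: (1 + 2) / 2 = 1.5 and 4 / 1 = 4.0
    print(masked_mean_torch(x, mask, dim=-1))  # tensor([1.5000, 4.0000])
    # Min and max fill masked entries with +inf / -inf before reducing.
    print(masked_min_torch(x, mask, dim=-1))   # tensor([1., 4.])
    print(masked_max_torch(x, mask, dim=-1))   # tensor([2., 4.])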
63
+
64
+
65
+ def get_discounted_mean_np(discount_factor=0.95):
66
+ def discounted_mean_np(x, axis=None):
67
+ discount_tensor = np.full(x.shape, discount_factor)
68
+ discount_tensor = np.cumprod(discount_tensor, axis=-2)
69
+ if axis is None:
70
+ return (x * discount_tensor).sum() / discount_tensor.sum()
71
+ else:
72
+ return (x * discount_tensor).sum(axis) / discount_tensor.sum(axis)
73
+
74
+ return discounted_mean_np
75
+
76
+
77
+ def get_masked_reduce_np(reduce_function):
78
+ def masked_reduce_np(x, mask=None, axis=None):
79
+ if mask is not None:
80
+ x = np.ma.array(x, mask=np.logical_not(mask))
81
+ return reduce_function(x, axis=axis)
82
+ else:
83
+ return reduce_function(x, axis=axis)
84
+
85
+ return masked_reduce_np
86
+
87
+
88
+ @dataclass
89
+ class CostParams:
90
+ scale: float
91
+ reduce: str
92
+ discount_factor: float
93
+
94
+ @staticmethod
95
+ def from_config(cfg: Config):
96
+ return CostParams(
97
+ scale=cfg.cost_scale,
98
+ reduce=cfg.cost_reduce,
99
+ discount_factor=cfg.discount_factor,
100
+ )
101
+
102
+
103
+ class BaseCostTorch:
104
+ """Base cost class defining reduce strategy and basic parameters.
105
+ Its __call__ definition is only a dummy example returning zeros, this class is intended to be
106
+ inherited from and __call__ redefined with an actual cost between the inputs.
107
+ """
108
+
109
+ def __init__(self, params: CostParams) -> None:
110
+ super().__init__()
111
+ self._reduce_fun = params.reduce
112
+ self.scale = params.scale
113
+
114
+ reduce_fun_torch_dict = {
115
+ "min": masked_min_torch,
116
+ "max": masked_max_torch,
117
+ "mean": masked_mean_torch,
118
+ "discounted_mean": get_masked_discounted_mean_torch(params.discount_factor),
119
+ "now": lambda *args, **kwargs: args[0][..., 0],
120
+ "final": lambda *args, **kwargs: args[0][..., -1],
121
+ }
122
+
123
+ self._reduce_fun = reduce_fun_torch_dict[params.reduce]
124
+
125
+ @property
126
+ def distance_bandwidth(self):
127
+ return 1
128
+
129
+ @property
130
+ def time_bandwidth(self):
131
+ return 1
132
+
133
+ def __call__(
134
+ self,
135
+ x1: Tensor,
136
+ x2: Tensor,
137
+ v1: Tensor,
138
+ v2: Tensor,
139
+ mask: Optional[Tensor] = None,
140
+ ) -> Tuple[Tensor, Any]:
141
+ """Compute the cost from given positions x1, x2 and velocities v1, v2
142
+ The base cost only returns 0 cost, use costs that inherit from this to compute an actual cost.
143
+
144
+ Args:
145
+ x1 (some shape, num_steps, 2): positions of the first agent
146
+ x2 (some shape, num_steps, 2): positions of the second agent
147
+ v1 (some shape, num_steps, 2): velocities of the first agent
148
+ v2 (some shape, num_steps, 2): velocities of the second agent
149
+ mask (some_shape, num_steps): mask set to True where the cost should be computed
150
+
151
+ Returns:
152
+ (some_shape) cost for the compared states of agent 1 and agent 2, as well as any
153
+ supplementary cost-related information
154
+ """
155
+ return (
156
+ self._reduce_fun(torch.zeros_like(x2[..., 0]), mask, dim=-1),
157
+ None,
158
+ )
159
+
160
+
161
+ class BaseCostNumpy:
162
+ """Base cost class defining reduce strategy and basic parameters.
163
+ Its __call__ definition is only a dummy example returning zeros, this class is intended to be
164
+ inherited from and __call__ redefined with an actual cost between the inputs.
165
+ """
166
+
167
+ def __init__(self, params: CostParams) -> None:
168
+ super().__init__()
169
+ self._reduce_fun = params.reduce
170
+ self.scale = params.scale
171
+
172
+ reduce_fun_np_dict = {
173
+ "min": get_masked_reduce_np(np.min),
174
+ "max": get_masked_reduce_np(np.max),
175
+ "mean": get_masked_reduce_np(np.mean),
176
+ "discounted_mean": get_masked_reduce_np(
177
+ get_discounted_mean_np(params.discount_factor)
178
+ ),
179
+ "now": get_masked_reduce_np(lambda *args, **kwargs: args[0][..., 0]),
180
+ "final": get_masked_reduce_np(lambda *args, **kwargs: args[0][..., -1]),
181
+ }
182
+ self._reduce_fun = reduce_fun_np_dict[params.reduce]
183
+
184
+ @property
185
+ def distance_bandwidth(self):
186
+ return 1
187
+
188
+ @property
189
+ def time_bandwidth(self):
190
+ return 1
191
+
192
+ def __call__(
193
+ self,
194
+ x1: np.ndarray,
195
+ x2: np.ndarray,
196
+ v1: np.ndarray,
197
+ v2: np.ndarray,
198
+ mask: Optional[np.ndarray] = None,
199
+ ) -> Tuple[np.ndarray, Any]:
200
+ """Compute the cost from given positions x1, x2 and velocities v1, v2
201
+ The base cost only returns 0 cost, use costs that inherit from this to compute an actual cost.
202
+
203
+ Args:
204
+ x1 (some shape, num_steps, 2): positions of the first agent
205
+ x2 (some shape, num_steps, 2): positions of the second agent
206
+ v1 (some shape, num_steps, 2): velocities of the first agent
207
+ v2 (some shape, num_steps, 2): velocities of the second agent
208
+ mask (some_shape, num_steps): mask set to True where the cost should be computed
209
+
210
+ Returns:
211
+ (some_shape) cost for the compared states of agent 1 and agent 2, as well as any
212
+ supplementary cost-related information
213
+ """
214
+ return (
215
+ self._reduce_fun(np.zeros_like(x2[..., 0]), mask, axis=-1),
216
+ None,
217
+ )
218
+
219
+
220
+ @dataclass
221
+ class DistanceCostParams(CostParams):
222
+ bandwidth: float
223
+
224
+ @staticmethod
225
+ def from_config(cfg: Config):
226
+ return DistanceCostParams(
227
+ scale=cfg.cost_scale,
228
+ reduce=cfg.cost_reduce,
229
+ bandwidth=cfg.distance_bandwidth,
230
+ discount_factor=cfg.discount_factor,
231
+ )
232
+
233
+
234
+ class DistanceCostTorch(BaseCostTorch):
235
+ def __init__(self, params: DistanceCostParams) -> None:
236
+ super().__init__(params)
237
+ self._bandwidth = params.bandwidth
238
+
239
+ @property
240
+ def distance_bandwidth(self):
241
+ return self._bandwidth
242
+
243
+ def __call__(
244
+ self, x1: Tensor, x2: Tensor, *args, mask: Optional[Tensor] = None, **kwargs
245
+ ) -> Tuple[Tensor, Tensor]:
246
+ """
247
+ Returns a cost estimation based on distance. Also returns distances between ego and pedestrians.
248
+ Args:
249
+ x1: First agent trajectory
250
+ x2: Second agent trajectory
251
+ mask: True where cost should be computed
252
+ Returns:
253
+ cost, distance_to_collision
254
+ """
255
+
256
+ dist = torch.square(x2 - x1).sum(-1)
257
+ if mask is not None:
258
+ dist = torch.masked_fill(dist, torch.logical_not(mask), 1e9)
259
+ cost = torch.exp(-dist / (2 * self._bandwidth))
260
+ return self.scale * self._reduce_fun(cost, mask=mask, dim=-1), dist
261
+
262
+
263
+ class DistanceCostNumpy(BaseCostNumpy):
264
+ def __init__(self, params: DistanceCostParams) -> None:
265
+ super().__init__(params)
266
+ self._bandwidth = params.bandwidth
267
+
268
+ @property
269
+ def distance_bandwidth(self):
270
+ return self._bandwidth
271
+
272
+ def __call__(
273
+ self,
274
+ x1: np.ndarray,
275
+ x2: np.ndarray,
276
+ *args,
277
+ mask: Optional[np.ndarray] = None,
278
+ **kwargs
279
+ ) -> Tuple[np.ndarray, np.ndarray]:
280
+ """
281
+ Returns a cost estimation based on distance. Also returns distances between ego and pedestrians.
282
+ Args:
283
+ x1: First agent trajectory
284
+ x2: Second agent trajectory
285
+ mask: True where cost should be computed
286
+ Returns:
287
+ cost, distance_to_collision
288
+ """
289
+ dist = np.square(x2 - x1).sum(-1)
290
+ if mask is not None:
291
+ dist = np.where(mask, dist, 1e9)
292
+ cost = np.exp(-dist / (2 * self._bandwidth))
293
+ return self.scale * self._reduce_fun(cost, mask=mask, axis=-1), dist
294
+
295
+
296
+ @dataclass
297
+ class TTCCostParams(CostParams):
298
+ distance_bandwidth: float
299
+ time_bandwidth: float
300
+ min_velocity_diff: float
301
+
302
+ @staticmethod
303
+ def from_config(cfg: Config):
304
+ return TTCCostParams(
305
+ scale=cfg.cost_scale,
306
+ reduce=cfg.cost_reduce,
307
+ distance_bandwidth=cfg.distance_bandwidth,
308
+ time_bandwidth=cfg.time_bandwidth,
309
+ min_velocity_diff=cfg.min_velocity_diff,
310
+ discount_factor=cfg.discount_factor,
311
+ )
312
+
313
+
314
+ class TTCCostTorch(BaseCostTorch):
315
+ def __init__(self, params: TTCCostParams) -> None:
316
+ super().__init__(params)
317
+ self._d_bw = params.distance_bandwidth
318
+ self._t_bw = params.time_bandwidth
319
+ self._min_v = params.min_velocity_diff
320
+
321
+ @property
322
+ def distance_bandwidth(self):
323
+ return self._d_bw
324
+
325
+ @property
326
+ def time_bandwidth(self):
327
+ return self._t_bw
328
+
329
+ def __call__(
330
+ self,
331
+ x1: Tensor,
332
+ x2: Tensor,
333
+ v1: Tensor,
334
+ v2: Tensor,
335
+ *args,
336
+ mask: Optional[Tensor] = None,
337
+ **kwargs
338
+ ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
339
+ """
340
+ Returns a cost estimation based on time to collision and distance to collision.
341
+ Also returns the estimated time to collision and the corresponding distance to collision.
342
+ Args:
343
+ x1: (some_shape, sequence_length, feature_shape) Initial position of the first agent
344
+ x2: (some_shape, sequence_length, feature_shape) Initial position of the second agent
345
+ v1: (some_shape, sequence_length, feature_shape) Velocity of the first agent
346
+ v2: (some_shape, sequence_length, feature_shape) Velocity of the second agent
347
+ mask: (some_shape, sequence_length) True where cost should be computed
348
+ Returns:
349
+ cost, (time_to_collision, distance_to_collision)
350
+ """
351
+ pos_diff = x1 - x2
352
+ velocity_diff = v1 - v2
353
+
354
+ dx = pos_diff[..., 0]
355
+ dy = pos_diff[..., 1]
356
+ vx = velocity_diff[..., 0]
357
+ vy = velocity_diff[..., 1]
358
+
359
+ speed_diff = (
360
+ torch.square(velocity_diff).sum(-1).clamp(self._min_v * self._min_v, None)
361
+ )
362
+
363
+ TTC = -(dx * vx + dy * vy) / speed_diff
364
+
365
+ distance_TTC = torch.where(
366
+ TTC < 0,
367
+ torch.sqrt(dx * dx + dy * dy),
368
+ torch.abs(vy * dx - vx * dy) / torch.sqrt(speed_diff),
369
+ )
370
+ TTC = torch.relu(TTC)
371
+ if mask is not None:
372
+ TTC = torch.masked_fill(TTC, torch.logical_not(mask), 1e9)
373
+ distance_TTC = torch.masked_fill(distance_TTC, torch.logical_not(mask), 1e9)
374
+
375
+ cost = self.scale * self._reduce_fun(
376
+ torch.exp(
377
+ -torch.square(TTC) / (2 * self._t_bw)
378
+ - torch.square(distance_TTC) / (2 * self._d_bw)
379
+ ),
380
+ mask=mask,
381
+ dim=-1,
382
+ )
383
+
384
+ return cost, (TTC, distance_TTC)
385
+
386
+
387
+ class TTCCostNumpy(BaseCostNumpy):
388
+ def __init__(self, params: TTCCostParams) -> None:
389
+ super().__init__(params)
390
+ self._d_bw = params.distance_bandwidth
391
+ self._t_bw = params.time_bandwidth
392
+ self._min_v = params.min_velocity_diff
393
+
394
+ @property
395
+ def distance_bandwidth(self):
396
+ return self._d_bw
397
+
398
+ @property
399
+ def time_bandwidth(self):
400
+ return self._t_bw
401
+
402
+ def __call__(
403
+ self,
404
+ x1: np.ndarray,
405
+ x2: np.ndarray,
406
+ v1: np.ndarray,
407
+ v2: np.ndarray,
408
+ *args,
409
+ mask: Optional[np.ndarray] = None,
410
+ **kwargs
411
+ ) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
412
+ """
413
+ Returns a cost estimation based on time to collision and distance to collision.
414
+ Also returns the estimated time to collision and the corresponding distance to collision.
415
+ Args:
416
+ x1: (some_shape, sequence_length, feature_shape) Initial position of the first agent
417
+ x2: (some_shape, sequence_length, feature_shape) Initial position of the second agent
418
+ v1: (some_shape, sequence_length, feature_shape) Velocity of the first agent
419
+ v2: (some_shape, sequence_length, feature_shape) Velocity of the second agent
420
+ mask: (some_shape, sequence_length) True where cost should be computed
421
+ Returns:
422
+ cost, (time_to_collision, distance_to_collision)
423
+ """
424
+ pos_diff = x1 - x2
425
+ velocity_diff = v1 - v2
426
+
427
+ dx = pos_diff[..., 0]
428
+ dy = pos_diff[..., 1]
429
+ vx = velocity_diff[..., 0]
430
+ vy = velocity_diff[..., 1]
431
+
432
+ speed_diff = np.maximum(
433
+ np.square(velocity_diff).sum(-1), self._min_v * self._min_v
434
+ )
435
+
436
+ TTC = -(dx * vx + dy * vy) / speed_diff
437
+ distance_TTC = np.where(
438
+ TTC < 0,
439
+ np.sqrt(dx * dx + dy * dy),
440
+ np.abs(vy * dx - vx * dy) / np.sqrt(speed_diff),
441
+ )
442
+ TTC = np.where(
443
+ TTC < 0,
444
+ 0,
445
+ TTC,
446
+ )
447
+ if mask is not None:
448
+ TTC = np.where(mask, TTC, 1e9)
449
+ distance_TTC = np.where(mask, distance_TTC, 1e9)
450
+
451
+ cost = self.scale * self._reduce_fun(
452
+ np.exp(
453
+ -np.square(TTC) / (2 * self._t_bw)
454
+ - np.square(distance_TTC) / (2 * self._d_bw)
455
+ ),
456
+ mask=mask,
457
+ axis=-1,
458
+ )
459
+ return cost, (TTC, distance_TTC)
460
+
461
+
462
+ def compute_v_from_x(x: Tensor, y: Tensor, dt: float):
463
+ """
464
+ Computes the velocity from the position and the time difference.
465
+ Args:
466
+ x: (some_shape, past_time_sequence, features) Past positions of the agents
467
+ y: (some_shape, future_time_sequence, features) Future positions of the agents
468
+ dt: Time difference
469
+ Returns:
470
+ v: (some_shape, future_time_sequence, features) Velocity of the agents
471
+ """
472
+ v = (y[..., 1:, :2] - y[..., :-1, :2]) / dt
473
+ v_0 = (y[..., 0:1, :2] - x[..., -1:, :2]) / dt
474
+ v = torch.cat((v_0, v), -2)
475
+ return v
476
+
477
+
478
+ def get_cost(
479
+ cost_function: BaseCostTorch,
480
+ x: torch.Tensor,
481
+ y_samples: torch.Tensor,
482
+ offset: torch.Tensor,
483
+ x_ego: torch.Tensor,
484
+ y_ego: torch.Tensor,
485
+ dt: float,
486
+ unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
487
+ mask: Optional[torch.Tensor] = None,
488
+ ) -> torch.Tensor:
489
+ """Compute cost samples from predicted future trajectories
490
+
491
+ Args:
492
+ cost_function: Cost function to use
493
+ x: (batch_size, n_agents, num_steps, state_dim) normalized tensor of history
494
+ y_samples: (batch_size, n_agents, n_samples, num_steps_future, state_dim) normalized tensor of predicted
495
+ future trajectory samples
496
+ offset: (batch_size, n_agents, state_dim) offset position from ego
497
+ x_ego: (batch_size, 1, num_steps, state_dim) tensor of ego history
498
+ y_ego: (batch_size, 1, num_steps_future, state_dim) tensor of ego future trajectory
499
+ dt: time step in trajectories
500
+ unnormalizer: function that takes in a trajectory and an offset and that outputs the
501
+ unnormalized trajectory
502
+ mask: tensor indicating where to compute the cost
503
+ Returns:
504
+ torch.Tensor: (batch_size, n_agents, n_samples) cost tensor
505
+ """
506
+ x = unnormalizer(x, offset)
507
+ y_samples = unnormalizer(y_samples, offset)
508
+ if offset.shape[1] > 1:
509
+ x_ego = unnormalizer(x_ego, offset[:, 0:1])
510
+ y_ego = unnormalizer(y_ego, offset[:, 0:1])
511
+
512
+ min_dim = min(x.shape[-1], y_samples.shape[-1])
513
+ x = x[..., :min_dim]
514
+ y_samples = y_samples[..., :min_dim]
515
+ x_ego = x_ego[..., :min_dim]
516
+ y_ego = y_ego[..., :min_dim]
517
+ assert x_ego.ndim == y_ego.ndim
518
+ if y_samples.shape[-1] < 5:
519
+ v_samples = compute_v_from_x(x.unsqueeze(-3), y_samples, dt)
520
+ else:
521
+ v_samples = y_samples[..., 3:5]
522
+
523
+ if y_ego.shape[-1] < 5:
524
+ v_ego = compute_v_from_x(x_ego, y_ego, dt)
525
+ else:
526
+ v_ego = y_ego[..., 3:5]
527
+ if mask is not None:
528
+ mask = torch.cat(
529
+ (mask[..., 0:1], torch.logical_and(mask[..., 1:], mask[..., :-1])), -1
530
+ )
531
+
532
+ cost, _ = cost_function(
533
+ x1=y_ego.unsqueeze(-3),
534
+ x2=y_samples,
535
+ v1=v_ego.unsqueeze(-3),
536
+ v2=v_samples,
537
+ mask=mask,
538
+ )
539
+ return cost
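To make the cost interface above concrete, here is a hedged sketch that builds a TTC cost and evaluates it on random trajectories; the parameter values are illustrative, not taken from the repository configs:

    import torch

    params = TTCCostParams(
        scale=1.0,
        reduce="mean",
        discount_factor=0.95,
        distance_bandwidth=2.0,
        time_bandwidth=0.5,
        min_velocity_diff=0.01,
    )
    cost_fn = TTCCostTorch(params)

    # (batch, num_steps, 2) positions and velocities for two agents.
    x1, x2 = torch.randn(4, 10, 2), torch.randn(4, 10, 2)
    v1, v2 = torch.randn(4, 10, 2), torch.randn(4, 10, 2)

    cost, (ttc, dist_at_ttc) = cost_fn(x1, x2, v1, v2)
    print(cost.shape)  # torch.Size([4]) after the "mean" reduction over time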
risk_biased/utils/load_model.py ADDED
@@ -0,0 +1,220 @@
1
+ import os
2
+ from typing import Callable, Optional, Tuple, Union
4
+ import warnings
5
+
6
+ from mmcv import Config
7
+ import torch
8
+ import wandb
9
+
10
+ from risk_biased.predictors.biased_predictor import (
11
+ LitTrajectoryPredictor,
12
+ LitTrajectoryPredictorParams,
13
+ )
14
+
15
+ from risk_biased.utils.config_argparse import config_argparse
16
+ from risk_biased.utils.cost import TTCCostParams
17
+ from risk_biased.utils.torch_utils import load_weights
18
+
19
+ from risk_biased.scene_dataset.loaders import SceneDataLoaders
20
+ from risk_biased.scene_dataset.scene import load_create_dataset
21
+
22
+ from risk_biased.utils.waymo_dataloader import WaymoDataloaders
23
+
24
+
25
+ def get_predictor(
26
+ config: Config, unnormalizer: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
27
+ ):
28
+ params = LitTrajectoryPredictorParams.from_config(config)
29
+ model_class = LitTrajectoryPredictor
30
+ ttc_params = TTCCostParams.from_config(config)
31
+ return model_class(params=params, unnormalizer=unnormalizer, cost_params=ttc_params)
32
+
33
+
34
+ def load_from_wandb_id(
35
+ log_id: str,
36
+ log_path: str,
37
+ entity: str,
38
+ project: str,
39
+ config: Optional[Config] = None,
40
+ load_last=False,
41
+ ) -> Tuple[LitTrajectoryPredictor, Union[WaymoDataloaders, SceneDataLoaders], Config]:
42
+ """
43
+ Load a model using a wandb id code.
44
+ Args:
45
+ log_id: the wandb id code
46
+ log_path: the wandb log directory path
47
+ config: An optional configuration argument, use these settings if not None, use the settings from the log directory otherwise
48
+ load_last: An optional argument, set to True to load the last checkpoint instead of the best one
49
+ Returns:
50
+ Predictor model and config file either loaded from the checkpoint or the one passed as argument.
51
+ """
52
+ list_matching = list(filter(lambda path: log_id in path, os.listdir(log_path)))
53
+ if len(list_matching) == 1:
54
+ list_ckpt = list(
55
+ filter(
56
+ lambda path: "epoch" in path and ".ckpt" in path,
57
+ os.listdir(os.path.join(log_path, list_matching[0], "files")),
58
+ )
59
+ )
60
+ if not load_last and len(list_ckpt) == 1:
61
+ print(f"Loading best model: {list_ckpt[0]}.")
62
+ checkpoint_path = os.path.join(
63
+ log_path, list_matching[0], "files", list_ckpt[0]
64
+ )
65
+ else:
66
+ print(f"Loading last checkpoint.")
67
+ checkpoint_path = os.path.join(
68
+ log_path, list_matching[0], "files", "last.ckpt"
69
+ )
70
+ config_path = os.path.join(
71
+ log_path, list_matching[0], "files", "learning_config.py"
72
+ )
73
+
74
+ if config is None:
75
+ config = config_argparse(config_path)
76
+ distant_model_type = None
77
+ else:
78
+ distant_config = config_argparse(config_path)
79
+ distant_model_type = distant_config.model_type
80
+ config["load_from"] = log_id
81
+
82
+ if config.model_type == "interaction_biased":
83
+ dataloaders = WaymoDataloaders(config)
84
+ else:
85
+ [data_train, data_val, data_test] = load_create_dataset(config)
86
+ dataloaders = SceneDataLoaders(
87
+ state_dim=config.state_dim,
88
+ num_steps=config.num_steps,
89
+ num_steps_future=config.num_steps_future,
90
+ batch_size=config.batch_size,
91
+ data_train=data_train,
92
+ data_val=data_val,
93
+ data_test=data_test,
94
+ num_workers=config.num_workers,
95
+ )
96
+
97
+ try:
98
+ if len(config.gpus):
99
+ map_location = "cuda"
100
+ else:
101
+ map_location = "cpu"
102
+ model = load_weights(
103
+ get_predictor(config, dataloaders.unnormalize_trajectory),
104
+ torch.load(checkpoint_path, map_location=map_location),
105
+ strict=True,
106
+ )
107
+ except RuntimeError:
108
+ raise RuntimeError(
109
+ f"The source model is of type {distant_model_type}."
110
+ + " It cannot be used to load the weights of the interaction biased model."
111
+ )
112
+
113
+ return model, dataloaders, config
114
+
115
+ else:
116
+ print("Trying to download logs from WandB...")
117
+ api = wandb.Api()
118
+ run = api.run(entity + "/" + project + "/" + log_id)
119
+ if run is not None:
120
+ checkpoint_path = os.path.join(
121
+ log_path, "downloaded_run-" + log_id, "files"
122
+ )
123
+ os.makedirs(checkpoint_path)
124
+ for file in run.files():
125
+ if file.name.endswith("ckpt") or file.name.endswith("config.py"):
126
+ file.download(checkpoint_path)
127
+ return load_from_wandb_id(
128
+ log_id, log_path, entity, project, config, load_last
129
+ )
130
+ else:
131
+ raise RuntimeError(
132
+ f"Error while loading checkpoint: Found {len(list_matching)} occurrences of the given id {log_id} in the logs at {log_path}."
133
+ )
134
+
135
+
136
+ def load_from_config(cfg: Config):
137
+ """
138
+ This function loads the predictor model and the data depending on which one is selected in the config.
139
+ If a "load_from" field is not empty, then tries to load the pre-trained model from the checkpoint.
140
+ The matching config file is loaded
141
+
142
+ Args:
143
+ cfg : Configuration that defines the model to be loaded
144
+
145
+ Returns:
146
+ loaded model and a new version of the config that is compatible with the checkpoint model that it could be loaded from
147
+ """
148
+
149
+ log_path = os.path.join(cfg.log_path, "wandb")
150
+ ignored_keys = [
151
+ "project",
152
+ "dataset_parameters",
153
+ "load_from",
154
+ "force_config",
155
+ "load_last",
156
+ ]
157
+
158
+ if "load_from" in cfg.keys() and cfg.load_from != "" and cfg.load_from:
159
+ if "load_last" in cfg.keys():
160
+ load_last = cfg["load_last"]
161
+ else:
162
+ load_last = False
163
+ if cfg.force_config:
164
+ warnings.warn(
165
+ f"Using local configuration but loading from run {cfg.load_from}. Will fail if local configuration is not compatible."
166
+ )
167
+ predictor, dataloaders, config = load_from_wandb_id(
168
+ log_id=cfg.load_from,
169
+ log_path=log_path,
170
+ entity=cfg.entity,
171
+ project=cfg.project,
172
+ config=cfg,
173
+ load_last=load_last,
174
+ )
175
+ else:
176
+ predictor, dataloaders, config = load_from_wandb_id(
177
+ log_id=cfg.load_from,
178
+ log_path=log_path,
179
+ entity=cfg.entity,
180
+ project=cfg.project,
181
+ load_last=load_last,
182
+ )
183
+ difference = False
184
+ warning_message = ""
185
+ for key, item in cfg.items():
186
+ try:
187
+ if config[key] != item:
188
+ if not difference:
189
+ warning_message += "When loading the model, the configuration was changed to match the configuration of the pre-trained model to be loaded.\n"
190
+ difference = True
191
+ if key not in ignored_keys:
192
+ warning_message += f" The value of '{key}' is now '{config[key]}' instead of '{item}'."
193
+ except KeyError:
194
+ if not difference:
195
+ warning_message += "When loading the model, the configuration was changed to match the configuration of the pre-trained model to be loaded."
196
+ difference = True
197
+ warning_message += f" The parameter '{key}' with value '{item}' does not exist for the model you are loading from, it is added."
198
+ config[key] = item
199
+ if warning_message != "":
200
+ warnings.warn(warning_message)
201
+ return predictor, dataloaders, config
202
+
203
+ else:
204
+ if cfg.model_type == "interaction_biased":
205
+ dataloaders = WaymoDataloaders(cfg)
206
+ else:
207
+ [data_train, data_val, data_test] = load_create_dataset(cfg)
208
+ dataloaders = SceneDataLoaders(
209
+ state_dim=cfg.state_dim,
210
+ num_steps=cfg.num_steps,
211
+ num_steps_future=cfg.num_steps_future,
212
+ batch_size=cfg.batch_size,
213
+ data_train=data_train,
214
+ data_val=data_val,
215
+ data_test=data_test,
216
+ num_workers=cfg.num_workers,
217
+ )
218
+
219
+ predictor = get_predictor(cfg, dataloaders.unnormalize_trajectory)
220
+ return predictor, dataloaders, cfg
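For orientation, a hedged sketch of how these loading utilities might be chained in an evaluation script; the run id is the placeholder already used in the help text above, and the rest of the configuration is assumed to come from the local config file:

    from risk_biased.utils.config_argparse import config_argparse
    from risk_biased.utils.load_model import load_from_config

    cfg = config_argparse()        # local config plus command-line overrides
    cfg.load_from = "1f1ho81a"     # placeholder wandb run id
    cfg.load_last = False          # prefer the best checkpoint over the last one

    predictor, dataloaders, cfg = load_from_config(cfg)
    print(type(predictor).__name__)  # LitTrajectoryPredictor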
risk_biased/utils/loss.py ADDED
@@ -0,0 +1,124 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import Tensor
5
+ from torch.distributions import MultivariateNormal
6
+
7
+
8
+ def reconstruction_loss(
9
+ pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
10
+ ):
11
+ """
12
+ pred (Tensor): (..., time, [x,y,(a),(vx,vy)])
13
+ truth (Tensor): (..., time, [x,y,(a),(vx,vy)])
14
+ mask_loss (Tensor): (..., time) Defaults to None.
15
+ """
16
+ min_feat_shape = min(pred.shape[-1], truth.shape[-1])
17
+ if min_feat_shape == 3:
18
+ assert pred.shape[-1] == truth.shape[-1]
19
+ return reconstruction_loss(
20
+ pred[..., :2], truth[..., :2], mask_loss
21
+ ) + reconstruction_loss(
22
+ torch.stack([torch.cos(pred[..., 2]), torch.sin(pred[..., 2])], -1),
23
+ torch.stack([torch.cos(truth[..., 2]), torch.sin(truth[..., 2])], -1),
24
+ mask_loss,
25
+ )
26
+ elif min_feat_shape >= 5:
27
+ assert pred.shape[-1] <= truth.shape[-1]
28
+ v_norm = torch.sum(torch.square(truth[..., 3:5]), -1, keepdim=True)
29
+ v_mask = v_norm > 1
30
+ return (
31
+ reconstruction_loss(pred[..., :2], truth[..., :2], mask_loss)
32
+ + reconstruction_loss(
33
+ torch.stack([torch.cos(pred[..., 2]), torch.sin(pred[..., 2])], -1)
34
+ * v_mask,
35
+ torch.stack([torch.cos(truth[..., 2]), torch.sin(truth[..., 2])], -1)
36
+ * v_mask,
37
+ mask_loss,
38
+ )
39
+ + reconstruction_loss(pred[..., 3:5], truth[..., 3:5], mask_loss)
40
+ )
41
+ elif min_feat_shape == 2:
42
+ if mask_loss is None:
43
+ return torch.mean(
44
+ torch.sqrt(
45
+ torch.sum(
46
+ torch.square(pred[..., :2] - truth[..., :2]), -1
47
+ ).clamp_min(1e-6)
48
+ )
49
+ )
50
+ else:
51
+ assert mask_loss.any()
52
+ mask_loss = mask_loss.float()
53
+ return torch.sum(
54
+ torch.sqrt(
55
+ torch.sum(
56
+ torch.square(pred[..., :2] - truth[..., :2]), -1
57
+ ).clamp_min(1e-6)
58
+ )
59
+ * mask_loss
60
+ ) / torch.sum(mask_loss).clamp_min(1)
61
+
62
+
63
+ def map_penalized_reconstruction_loss(
64
+ pred: torch.Tensor,
65
+ truth: torch.Tensor,
66
+ map: torch.Tensor,
67
+ mask_map: torch.Tensor,
68
+ mask_loss: Optional[torch.Tensor] = None,
69
+ map_importance: float = 0.1,
70
+ ):
71
+ """
72
+ pred (Tensor): (batch_size, num_agents, time, [x,y,(a),(vx,vy)])
73
+ truth (Tensor): (batch_size, num_agents, time, [x,y,(a),(vx,vy)])
74
+ map (Tensor): (batch_size, num_objects, object_sequence_length, [x, y, ...])
75
+ mask_map (Tensor): (...)
76
+ mask_loss (Tensor): (..., time) Defaults to None.
77
+
78
+ """
79
+ # b, a, o, s, f b, a, o, t, s, f
80
+ map_distance, _ = (
81
+ (map[:, None, :, :, :2] - pred[:, :, None, -1, None, :2])
82
+ .square()
83
+ .sum(-1)
84
+ .min(2)
85
+ )
86
+ map_distance = map_distance.sqrt().clamp(0.5, 3)
87
+ if mask_loss is not None:
88
+ map_loss = (map_distance * mask_loss[..., -1:]).sum() / mask_loss[..., -1].sum()
89
+ else:
90
+ map_loss = map_distance.mean()
91
+
92
+ rec_loss = reconstruction_loss(pred, truth, mask_loss)
93
+
94
+ return rec_loss + map_importance * map_loss
95
+
96
+
97
+ def cce_loss_with_logits(pred_logits: torch.Tensor, truth: torch.Tensor):
98
+ pred_log = pred_logits.log_softmax(-1)
99
+ return -(pred_log * truth).sum(-1).mean()
100
+
101
+
102
+ def risk_loss_function(
103
+ pred: torch.Tensor,
104
+ truth: torch.Tensor,
105
+ mask: torch.Tensor,
106
+ factor: float = 100.0,
107
+ ) -> torch.Tensor:
108
+ """
109
+ Loss function for the risk comparison. This is asymmetric because it is preferred that the model over-estimates
110
+ the risk rather than under-estimates it.
111
+ Args:
112
+ pred: (same_shape) The predicted risks
113
+ truth: (same_shape) The reference risks to match
114
+ mask: (same_shape) A mask with 1 where the loss should be computed and 0 elsewhere.
115
+ factor: A scaling applied to the error before the asymmetric transform; the larger this value,
116
+ the smaller the error at which the logarithmic (asymmetric) regime kicks in.
117
+ Returns:
118
+ Scalar loss value
119
+ """
120
+ error = pred - truth
121
+ error = error * factor
122
+ error = torch.where(error > 1, (error + 1e-6).log(), error.abs())
123
+ error = (error * mask).sum() / mask.sum()
124
+ return error
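A small illustrative check of the asymmetry (made-up numbers): once error * factor exceeds 1 the penalty only grows logarithmically, so over-estimating the risk is much cheaper than under-estimating it by the same amount:

    import torch

    truth = torch.zeros(3)
    mask = torch.ones(3)

    over = risk_loss_function(truth + 0.5, truth, mask, factor=100.0)   # log(50) ~ 3.9
    under = risk_loss_function(truth - 0.5, truth, mask, factor=100.0)  # |-50| = 50.0
    print(over.item(), under.item())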
risk_biased/utils/metrics.py ADDED
@@ -0,0 +1,81 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+
5
+
6
+ def FDE(
7
+ pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
8
+ ):
9
+ """
10
+ pred (Tensor): (..., time, xy)
11
+ truth (Tensor): (..., time, xy)
12
+ mask_loss (Tensor): (..., time) Defaults to None.
13
+ """
14
+ if mask_loss is None:
15
+ return torch.mean(
16
+ torch.sqrt(
17
+ torch.sum(torch.square(pred[..., -1, :2] - truth[..., -1, :2]), -1)
18
+ )
19
+ )
20
+ else:
21
+ mask_loss = mask_loss.float()
22
+ return torch.sum(
23
+ torch.sqrt(
24
+ torch.sum(torch.square(pred[..., -1, :2] - truth[..., -1, :2]), -1)
25
+ )
26
+ * mask_loss[..., -1]
27
+ ) / torch.sum(mask_loss[..., -1]).clamp_min(1)
28
+
29
+
30
+ def ADE(
31
+ pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
32
+ ):
33
+ """
34
+ pred (Tensor): (..., time, xy)
35
+ truth (Tensor): (..., time, xy)
36
+ mask_loss (Tensor): (..., time) Defaults to None.
37
+ """
38
+ if mask_loss is None:
39
+ return torch.mean(
40
+ torch.sqrt(
41
+ torch.sum(torch.square(pred[..., :, :2] - truth[..., :, :2]), -1)
42
+ )
43
+ )
44
+ else:
45
+ mask_loss = mask_loss.float()
46
+ return torch.sum(
47
+ torch.sqrt(
48
+ torch.sum(torch.square(pred[..., :, :2] - truth[..., :, :2]), -1)
49
+ )
50
+ * mask_loss
51
+ ) / torch.sum(mask_loss).clamp_min(1)
52
+
53
+
54
+ def minFDE(
55
+ pred: torch.Tensor, truth: torch.Tensor, mask_loss: Optional[torch.Tensor] = None
56
+ ):
57
+ """
58
+ pred (Tensor): (..., n_samples, time, xy)
59
+ truth (Tensor): (..., time, xy)
60
+ mask_loss (Tensor): (..., time) Defaults to None.
61
+ """
62
+ if mask_loss is None:
63
+ min_distances, _ = torch.min(
64
+ torch.sqrt(
65
+ torch.sum(torch.square(pred[..., -1, :2] - truth[..., -1, :2]), -1)
66
+ ),
67
+ -1,
68
+ )
69
+ return torch.mean(min_distances)
70
+ else:
71
+ mask_loss = mask_loss[..., -1].float()
72
+ final_distances = torch.sqrt(
73
+ torch.sum(torch.square(pred[..., -1, :2] - truth[..., -1, :2]), -1)
74
+ )
75
+ max_final_distance = torch.max(final_distances * mask_loss)
76
+ min_distances, _ = torch.min(
77
+ final_distances + max_final_distance * (1 - mask_loss), -1
78
+ )
79
+ return torch.sum(min_distances * mask_loss.any(-1)) / torch.sum(
80
+ mask_loss.any(-1)
81
+ ).clamp_min(1)
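As a quick sanity check of these metrics on random data (illustrative only, not a repository test):

    import torch

    truth = torch.randn(8, 12, 2)                  # (batch, time, xy)
    pred = truth + 0.1 * torch.randn(8, 12, 2)
    samples = truth.unsqueeze(1) + 0.1 * torch.randn(8, 6, 12, 2)  # (batch, n_samples, time, xy)

    print(ADE(pred, truth))                     # mean displacement over all steps
    print(FDE(pred, truth))                     # displacement at the final step
    print(minFDE(samples, truth.unsqueeze(1)))  # best final displacement over the 6 samples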
risk_biased/utils/planner_utils.py ADDED
@@ -0,0 +1,462 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+
9
+ from risk_biased.mpc_planner.planner_cost import TrackingCost
10
+ from risk_biased.utils.cost import BaseCostTorch
11
+ from risk_biased.utils.risk import AbstractMonteCarloRiskEstimator
12
+
13
+
14
+ def get_rotation_matrix(angle, device):
15
+ c = torch.cos(angle)
16
+ s = torch.sin(angle)
17
+ rot_matrix = torch.stack(
18
+ (torch.stack((c, s), -1), torch.stack((-s, c), -1)), -1
19
+ ).to(device)
20
+ return rot_matrix
21
+
22
+
23
+ class AbstractState(ABC):
24
+ """
25
+ State representation using an underlying tensor. Position, Velocity, and Angle can be accessed.
26
+ """
27
+
28
+ @property
29
+ @abstractmethod
30
+ def position(self) -> torch.Tensor:
31
+ """Extract position information from the state tensor
32
+
33
+ Returns:
34
+ position_tensor of size (..., 2)
35
+ """
36
+
37
+ @property
38
+ @abstractmethod
39
+ def velocity(self) -> torch.Tensor:
40
+ """Extract velocity information from the state tensor
41
+
42
+ Returns:
43
+ velocity_tensor of size (..., 2)
44
+ """
45
+
46
+ @property
47
+ @abstractmethod
48
+ def angle(self) -> torch.Tensor:
49
+ """Extract heading angle information from the state tensor
50
+
51
+ Returns:
52
+ angle_tensor of size (..., 1)
53
+ """
54
+
55
+ @abstractmethod
56
+ def get_states(self, dim: int) -> torch.Tensor:
57
+ """Return the underlying states tensor with dim 2, 4 or 5 ([x, y], [x, y, vx, vy], or [x, y, angle, vx, vy])."""
58
+
59
+ @abstractmethod
60
+ def rotate(self, angle: float, in_place: bool) -> AbstractState:
61
+ """Rotate the state by the given angle
62
+ Args:
63
+ angle: in radians
64
+ in_place: whether to change the object itself or return a rotated copy
65
+ Returns:
66
+ rotated self or rotated copy of self
67
+ """
68
+
69
+ @abstractmethod
70
+ def translate(self, translation: torch.Tensor, in_place: bool) -> AbstractState:
71
+ """Translate the state by the given translation
72
+ Args:
73
+ translation: translation vector in 2 dimensions
74
+ in_place: whether to change the object itself or return a translated copy
75
+ """
76
+
77
+ # Define overloading operators to behave as a tensor for some operations
78
+ def __getitem__(self, key) -> AbstractState:
79
+ """
80
+ Use get item on the underlying tensor to get the item at the given key.
81
+ Always returns a velocity state so that if the underlying time sequence is reduced to one step, the velocity is still accessible.
82
+ """
83
+ if isinstance(key, int):
84
+ key = (key, Ellipsis, slice(None, None, None))
85
+ elif Ellipsis not in key:
86
+ key = (*key, Ellipsis, slice(None, None, None))
87
+ else:
88
+ key = (*key, slice(None, None, None))
89
+
90
+ return to_state(
91
+ torch.cat(
92
+ (
93
+ self.position[key],
94
+ self.velocity[key],
95
+ ),
96
+ dim=-1,
97
+ ),
98
+ self.dt,
99
+ )
100
+
101
+ @property
102
+ def shape(self):
103
+ return self._states.shape[:-1]
104
+
105
+
106
+ def to_state(in_tensor: torch.Tensor, dt: float) -> AbstractState:
107
+ if in_tensor.shape[-1] == 2:
108
+ return PositionSequenceState(in_tensor, dt)
109
+ elif in_tensor.shape[-1] == 4:
110
+ return PositionVelocityState(in_tensor, dt)
111
+ else:
112
+ assert in_tensor.shape[-1] > 4
113
+ return PositionAngleVelocityState(in_tensor, dt)
114
+
115
+
116
+ class PositionSequenceState(AbstractState):
117
+ """
118
+ State representation with an underlying tensor defining only positions.
119
+ """
120
+
121
+ def __init__(self, states: torch.Tensor, dt: float) -> None:
122
+ super().__init__()
123
+ assert (
124
+ states.shape[-1] == 2
125
+ ) # Check that the input tensor defines only the position
126
+ assert (
127
+ states.ndim > 1 and states.shape[-2] > 1
128
+ ) # Check that the input tensor defines a sequence of positions (otherwise velocity cannot be computed)
129
+ self.dt = dt
130
+ self._states = states.clone()
131
+
132
+ @property
133
+ def position(self) -> torch.Tensor:
134
+ return self._states
135
+
136
+ @property
137
+ def velocity(self) -> torch.Tensor:
138
+ vel = (self._states[..., 1:, :] - self._states[..., :-1, :]) / self.dt
139
+ vel = torch.cat((vel[..., 0:1, :], vel), dim=-2)
140
+ return vel.clone()
141
+
142
+ @property
143
+ def angle(self) -> torch.Tensor:
144
+ vel = self.velocity
145
+ angle = torch.arctan2(vel[..., 1:2], vel[..., 0:1])
146
+ return angle
147
+
148
+ def get_states(self, dim: int = 2) -> torch.Tensor:
149
+ if dim == 2:
150
+ return self._states.clone()
151
+ elif dim == 4:
152
+ return torch.cat((self._states.clone(), self.velocity), dim=-1)
153
+ elif dim == 5:
154
+ return torch.cat((self._states.clone(), self.angle, self.velocity), dim=-1)
155
+ else:
156
+ raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")
157
+
158
+ def rotate(self, angle: float, in_place: bool = False) -> PositionSequenceState:
159
+ """Rotate the state by the given angle in radians"""
160
+ rot_matrix = get_rotation_matrix(angle, self._states.device)
161
+ if in_place:
162
+ self._states = (rot_matrix @ self._states.unsqueeze(-1)).squeeze(-1)
163
+ return self
164
+ else:
165
+ return to_state(
166
+ (rot_matrix @ self._states.unsqueeze(-1).clone()).squeeze(-1), self.dt
167
+ )
168
+
169
+ def translate(
170
+ self, translation: torch.Tensor, in_place: bool = False
171
+ ) -> PositionSequenceState:
172
+ """Translate the state by the given translation"""
173
+ if in_place:
174
+ self._states[..., :2] += translation.expand_as(self._states[..., :2])
175
+ return self
176
+ else:
177
+ return to_state(
178
+ self._states[..., :2].clone()
179
+ + translation.expand_as(self._states[..., :2]),
180
+ self.dt,
181
+ )
182
+
183
+
184
+ class PositionVelocityState(AbstractState):
185
+ """
186
+ State representation with an underlying tensor defining position and velocity.
187
+ """
188
+
189
+ def __init__(self, states: torch.Tensor, dt) -> None:
190
+ super().__init__()
191
+ assert states.shape[-1] == 4
192
+ self._states = states.clone()
193
+ self.dt = dt
194
+
195
+ @property
196
+ def position(self) -> torch.Tensor:
197
+ return self._states[..., :2]
198
+
199
+ @property
200
+ def velocity(self) -> torch.Tensor:
201
+ return self._states[..., 2:4]
202
+
203
+ @property
204
+ def angle(self) -> torch.Tensor:
205
+ vel = self.velocity
206
+ angle = torch.arctan2(vel[..., 1:2], vel[..., 0:1])
207
+ return angle
208
+
209
+ def get_states(self, dim: int = 4) -> torch.Tensor:
210
+ if dim == 2:
211
+ return self._states[..., :2].clone()
212
+ elif dim == 4:
213
+ return self._states.clone()
214
+ elif dim == 5:
215
+ return torch.cat(
216
+ (
217
+ self._states[..., :2].clone(),
218
+ self.angle,
219
+ self._states[..., 2:].clone(),
220
+ ),
221
+ dim=-1,
222
+ )
223
+ else:
224
+ raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")
225
+
226
+ def rotate(
227
+ self, angle: torch.Tensor, in_place: bool = False
228
+ ) -> PositionVelocityState:
229
+ """Rotate the state by the given angle in radians"""
230
+ rot_matrix = get_rotation_matrix(angle, self._states.device)
231
+ rotated_pos = (rot_matrix @ self.position.unsqueeze(-1)).squeeze(-1)
232
+ rotated_vel = (rot_matrix @ self.velocity.unsqueeze(-1)).squeeze(-1)
233
+ if in_place:
234
+ self._states = torch.cat((rotated_pos, rotated_vel), dim=-1)
235
+ return self
236
+ else:
237
+ return to_state(torch.cat((rotated_pos, rotated_vel), dim=-1), self.dt)
238
+
239
+ def translate(
240
+ self, translation: torch.Tensor, in_place: bool = False
241
+ ) -> PositionVelocityState:
242
+ """Translate the state by the given translation"""
243
+ if in_place:
244
+ self._states[..., :2] += translation.expand_as(self._states[..., :2])
245
+ return self
246
+ else:
247
+ return to_state(
248
+ torch.cat(
249
+ (
250
+ self._states[..., :2].clone()
251
+ + translation.expand_as(self._states[..., :2]),
252
+ self._states[..., 2:].clone(),
253
+ ),
254
+ dim=-1,
255
+ ),
256
+ self.dt,
257
+ )
258
+
259
+
260
+ class PositionAngleVelocityState(AbstractState):
261
+ """
262
+ State representation with an underlying tensor representing position angle and velocity.
263
+ """
264
+
265
+ def __init__(self, states: torch.Tensor, dt: float) -> None:
266
+ super().__init__()
267
+ assert states.shape[-1] == 5
268
+ self._states = states.clone()
269
+ self.dt = dt
270
+
271
+ @property
272
+ def position(self) -> torch.Tensor:
273
+ return self._states[..., :2].clone()
274
+
275
+ @property
276
+ def velocity(self) -> torch.Tensor:
277
+ return self._states[..., 3:5].clone()
278
+
279
+ @property
280
+ def angle(self) -> torch.Tensor:
281
+ return self._states[..., 2:3].clone()
282
+
283
+ def get_states(self, dim: int = 5) -> torch.Tensor:
284
+ if dim == 2:
285
+ return self._states[..., :2].clone()
286
+ elif dim == 4:
287
+ return torch.cat(
288
+ (self._states[..., :2].clone(), self._states[..., 3:].clone()), dim=-1
289
+ )
290
+ elif dim == 5:
291
+ return self._states.clone()
292
+ else:
293
+ raise RuntimeError(f"State dimension must be either 2, 4, or 5. Got {dim}")
294
+
295
+ def rotate(
296
+ self, angle: float, in_place: bool = False
297
+ ) -> PositionAngleVelocityState:
298
+ """Rotate the state by the given angle in radians"""
299
+ rot_matrix = get_rotation_matrix(angle, self._states.device)
300
+ rotated_pos = (rot_matrix @ self.position.unsqueeze(-1)).squeeze(-1)
301
+ rotated_angle = self.angle + angle
302
+ rotated_vel = (rot_matrix @ self.velocity.unsqueeze(-1)).squeeze(-1)
303
+ if in_place:
304
+ self._states = torch.cat((rotated_pos, rotated_angle, rotated_vel), -1)
305
+ return self
306
+ else:
307
+ return to_state(
308
+ torch.cat((rotated_pos, rotated_angle, rotated_vel), -1), self.dt
309
+ )
310
+
311
+ def translate(
312
+ self, translation: torch.Tensor, in_place: bool = False
313
+ ) -> PositionAngleVelocityState:
314
+ """Translate the state by the given translation"""
315
+ if in_place:
316
+ self._states[..., :2] += translation.expand_as(self._states[..., :2])
317
+ return self
318
+ else:
319
+ return to_state(
320
+ torch.cat(
321
+ (
322
+ self._states[..., :2]
323
+ + translation.expand_as(self._states[..., :2]),
324
+ self._states[..., 2:],
325
+ ),
326
+ dim=-1,
327
+ ),
328
+ self.dt,
329
+ )
330
+
331
+
332
+ def get_interaction_cost(
333
+ ego_state_future: AbstractState,
334
+ ado_state_future_samples: AbstractState,
335
+ interaction_cost_function: BaseCostTorch,
336
+ ) -> torch.Tensor:
337
+ """Computes interaction cost samples from predicted ado future trajectories and a batch of ego
338
+ future trajectories
339
+
340
+ Args:
341
+ ego_state_future: ((num_control_samples), num_agents, num_steps_future) ego state future
342
+ future trajectory
343
+ ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
344
+ predicted ado state trajectory samples
345
+ interaction_cost_function: interaction cost function between ego and (stochastic) ado
347
+
348
+ Returns:
349
+ (num_control_samples, num_agents, num_prediction_samples) interaction cost tensor
350
+ """
351
+ if len(ego_state_future.shape) == 2:
352
+ x_ego = ego_state_future.position.unsqueeze(0)
353
+ v_ego = ego_state_future.velocity.unsqueeze(0)
354
+ else:
355
+ x_ego = ego_state_future.position
356
+ v_ego = ego_state_future.velocity
357
+
358
+ num_control_samples = ego_state_future.shape[0]
359
+ ado_position_future_samples = ado_state_future_samples.position.unsqueeze(0).expand(
360
+ num_control_samples, -1, -1, -1, -1
361
+ )
362
+
363
+ v_samples = ado_state_future_samples.velocity.unsqueeze(0).expand(
364
+ num_control_samples, -1, -1, -1, -1
365
+ )
366
+
367
+ interaction_cost, _ = interaction_cost_function(
368
+ x1=x_ego.unsqueeze(1),
369
+ x2=ado_position_future_samples,
370
+ v1=v_ego.unsqueeze(1),
371
+ v2=v_samples,
372
+ )
373
+ return interaction_cost.permute(0, 2, 1)
374
+
375
+
376
+ def evaluate_risk(
377
+ risk_level: float,
378
+ cost: torch.Tensor,
379
+ weights: torch.Tensor,
380
+ risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
381
+ ) -> torch.Tensor:
382
+ """Returns a risk tensor given costs and optionally a risk level
383
+
384
+ Args:
385
+ risk_level (optional): a risk-level float. If 0.0, risk-neutral expectation will be
386
+ returned. Defaults to 0.0.
387
+ cost: (num_control_samples, num_agents, num_prediction_samples) cost tensor
388
+ weights: (num_control_samples, num_agents, num_prediction_samples) probability weight of the cost tensor
389
+ risk_estimator (optional): a Monte Carlo risk estimator. Defaults to None.
390
+
391
+ Returns:
392
+ (num_control_samples, num_agents) risk tensor
393
+ """
394
+ num_control_samples, num_agents, _ = cost.shape
395
+
396
+ if risk_level == 0.0:
397
+ risk = cost.mean(dim=-1)
398
+ else:
399
+ assert risk_estimator is not None, "no risk estimator is specified"
400
+ risk = risk_estimator(
401
+ risk_level * torch.ones(num_control_samples, num_agents),
402
+ cost,
403
+ weights=weights,
404
+ )
405
+ return risk
406
+
407
+
408
+ def evaluate_control_sequence(
409
+ control_sequence: torch.Tensor,
410
+ dynamics_model,
411
+ ego_state_history: AbstractState,
412
+ ego_state_target_trajectory: AbstractState,
413
+ ado_state_future_samples: AbstractState,
414
+ sample_weights: torch.Tensor,
415
+ interaction_cost_function: BaseCostTorch,
416
+ tracking_cost_function: TrackingCost,
417
+ risk_level: float = 0.0,
418
+ risk_estimator: Optional[AbstractMonteCarloRiskEstimator] = None,
419
+ ) -> Tuple[float, float]:
420
+ """Returns the risk and tracking cost evaluation of the given control sequence
421
+
422
+ Args:
423
+ control_sequence: (num_steps_future, control_dim) tensor of control sequence
424
+ dynamics_model: dynamics model for control
425
+ ego_state_target_trajectory: (num_steps_future) tensor of ego target
426
+ state trajectory
427
+ ado_state_future_samples: (num_prediction_samples, num_agents, num_steps_future)
428
+ of predicted ado trajectory samples states
429
+ sample_weights: (num_prediction_samples, num_agents) tensor of probability weights of the samples
430
+ interaction_cost_function: interaction cost function between ego and (stochastic) ado
431
+ tracking_cost_function: deterministic tracking cost that does not involve ado
432
+ risk_level: risk_level (optional): a risk-level float. If 0.0, risk-neutral expectation
433
+ is used. Defaults to 0.0.
434
+ risk_estimator (optional): a Monte Carlo risk estimator. Defaults to None.
435
+
436
+ Returns:
437
+ tuple of (interaction risk, tracking_cost)
438
+ """
439
+ ego_state_current = ego_state_history[..., -1]
440
+ ego_state_future = dynamics_model.simulate(ego_state_current, control_sequence)
441
+ # state starts with x, y, angle, vx, vy
442
+ tracking_cost = tracking_cost_function(
443
+ ego_state_future.position,
444
+ ego_state_target_trajectory.position,
445
+ ego_state_target_trajectory.velocity,
446
+ )
447
+
448
+ interaction_cost = get_interaction_cost(
449
+ ego_state_future,
450
+ ado_state_future_samples,
451
+ interaction_cost_function,
452
+ )
453
+
454
+ interaction_risk = evaluate_risk(
455
+ risk_level,
456
+ interaction_cost,
457
+ sample_weights.permute(1, 0).unsqueeze(0).expand_as(interaction_cost),
458
+ risk_estimator,
459
+ )
460
+
461
+ # TODO: averaging over agents but we might want to reduce a different way
462
+ return (interaction_risk.mean().item(), tracking_cost.mean().item())
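To illustrate the state wrappers defined above, a minimal sketch on random data (dt and all numbers are made up for this note):

    import math
    import torch

    traj = torch.randn(3, 10, 4)        # (agents, time, [x, y, vx, vy])
    state = to_state(traj, dt=0.1)      # dispatches to PositionVelocityState

    rotated = state.rotate(torch.tensor(math.pi / 2))    # rotated copy
    shifted = state.translate(torch.tensor([1.0, 0.0]))  # translated copy

    print(state.position.shape, state.velocity.shape, state.angle.shape)
    # torch.Size([3, 10, 2]) torch.Size([3, 10, 2]) torch.Size([3, 10, 1])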
risk_biased/utils/risk.py ADDED
@@ -0,0 +1,571 @@
1
+ import inspect
2
+ import math
3
+ import warnings
4
+ from abc import ABC, abstractmethod
5
+
6
+ import torch
7
+ from torch import Tensor
8
+
9
+
10
+ class AbstractMonteCarloRiskEstimator(ABC):
11
+ """Abstract class for Monte Carlo estimation of risk objectives"""
12
+
13
+ @abstractmethod
14
+ def __call__(self, risk_level: Tensor, cost: Tensor) -> Tensor:
15
+ """Computes and returns the risk objective estimated on the cost tensor
16
+
17
+ Args:
18
+ risk_level: (batch_size,) tensor of risk-level at which the risk objective is computed
19
+ cost: (batch_size, n_samples) tensor of cost samples
20
+ Returns:
21
+ risk tensor of size (batch_size,)
22
+ """
23
+
24
+
25
+ class EntropicRiskEstimator(AbstractMonteCarloRiskEstimator):
26
+ """Monte Carlo estimator for the entropic risk objective.
27
+ This estimator computes the entropic risk as 1/risk_level * log( mean( exp(risk_level * cost), 1))
28
+ However, this is unstable.
29
+ When risk_level is large, the logsumexp trick is used.
30
+ When risk_level is small, it computes entropic_risk for small values of risk_level as the second order Taylor expansion instead.
31
+
32
+ Args:
33
+ eps: Risk-level threshold to switch between logsumexp and Taylor expansion. Defaults to 1e-4.
34
+ """
35
+
36
+ def __init__(self, eps: float = 1e-4) -> None:
37
+ self.eps = eps
38
+
39
+ def __call__(self, risk_level: Tensor, cost: Tensor, weights: Tensor) -> Tensor:
40
+ """Computes and returns the entropic risk estimated on the cost tensor
41
+
42
+ Args:
43
+ risk_level: (batch_size, n_agents,) tensor of risk-level at which the risk objective is computed
44
+ cost: (batch_size, n_agents, n_samples) cost tensor
45
+ weights: (batch_size, n_agents, n_samples) tensor of weights for the cost samples
46
+ Returns:
47
+ entropic risk tensor of size (batch_size,)
48
+ """
49
+ weights = weights / weights.sum(dim=-1, keepdim=True)
50
+ batch_size, n_agents, n_samples = cost.shape
51
+ entropic_risk_cost_large_sigma = (
52
+ ((risk_level.view(batch_size, n_agents, 1) * cost).exp() * weights)
53
+ .sum(-1)
54
+ .log()
55
+ ) / risk_level
56
+
57
+ mean = (cost * weights).sum(dim=-1)
58
+ var = (cost**2 * weights).sum(dim=-1) - mean**2
59
+
60
61
+ entropic_risk_cost_small_sigma = mean + 0.5 * risk_level * var
62
+
63
+ return torch.where(
64
+ torch.abs(risk_level) > self.eps,
65
+ entropic_risk_cost_large_sigma,
66
+ entropic_risk_cost_small_sigma,
67
+ )
68
+
69
+
70
+ class CVaREstimator(AbstractMonteCarloRiskEstimator):
71
+ """Monte Carlo estimator for the conditional value-at-risk objective.
72
+ This estimator is proposed in the following references, and shown to be consistent.
73
+ - Hong et al. (2014), "Monte Carlo Methods for Value-at-Risk and Conditional Value-at-Risk: A Review"
74
+ - Trindade et al. (2007), "Financial prediction with constrained tail risk"
75
+ When risk_level is larger than 1 - eps, it falls back to the max operator
76
+
77
+ Args:
79
+ eps: Risk-level threshold to switch between CVaR and Max. Defaults to 1e-4.
80
+ """
81
+
82
+ def __init__(self, eps: float = 1e-4) -> None:
83
+ self.eps = eps
84
+
85
+ def __call__(self, risk_level: Tensor, cost: Tensor, weights: Tensor) -> Tensor:
86
+ """Computes and returns the conditional value-at-risk estimated on the cost tensor
87
+
88
+ Args:
89
+ risk_level: (batch_size, n_agents) tensor of risk-level in [0, 1] at which the CVaR risk is computed
90
+ cost: (batch_size, n_agents, n_samples) cost tensor
91
+ weights: (batch_size, n_agents, n_samples) tensor of weights for the cost samples
92
+
93
+ Returns:
94
+ conditional value-at-risk tensor of size (batch_size, n_agents)
95
+ """
96
+ assert risk_level.shape[0] == cost.shape[0]
97
+ assert risk_level.shape[1] == cost.shape[1]
98
+ if weights is None:
99
+ weights = torch.ones_like(cost) / cost.shape[-1]
100
+ else:
101
+ weights = weights / weights.sum(dim=-1, keepdim=True)
102
+ if not (torch.all(0.0 <= risk_level) and torch.all(risk_level <= 1.0)):
103
+ warnings.warn(
104
+ "risk_level is defined only between 0.0 and 1.0 for CVaR. Exceeded values will be clamped."
105
+ )
106
+ risk_level = torch.clamp(risk_level, min=0.0, max=1.0)
107
+
108
+ cvar_risk_high = cost.max(dim=-1).values
109
+
110
+ sorted_indices = torch.argsort(cost, dim=-1)
111
+ # cost_sorted = cost.sort(dim=-1, descending=False).values
112
+ cost_sorted = torch.gather(cost, -1, sorted_indices)
113
+ weights_sorted = torch.gather(weights, -1, sorted_indices)
114
+ idx_to_choose = torch.argmax(
115
+ (weights_sorted.cumsum(dim=-1) >= risk_level.unsqueeze(-1)).float(), -1
116
+ )
117
+
118
+ value_at_risk_mc = cost_sorted.gather(-1, idx_to_choose.unsqueeze(-1)).squeeze(
119
+ -1
120
+ )
121
+
122
+ # weights_at_risk_mc = 1 - weights_sorted.cumsum(-1).gather(
123
+ # -1, idx_to_choose.unsqueeze(-1)
124
+ # ).squeeze(-1)
125
+ # cvar_risk_mc = value_at_risk_mc + (
126
+ # (torch.relu(cost - value_at_risk_mc.unsqueeze(-1)) * weights).sum(dim=-1)
127
+ # / weights_at_risk_mc
128
+ # )
129
+ # cvar = torch.where(weights_at_risk_mc < self.eps, cvar_risk_high, cvar_risk_mc)
130
+
131
+ cvar_risk_mc = value_at_risk_mc + 1 / (1 - risk_level) * (
132
+ (torch.relu(cost - value_at_risk_mc.unsqueeze(-1)) * weights).sum(dim=-1)
133
+ )
134
+ cvar = torch.where(risk_level > 1 - self.eps, cvar_risk_high, cvar_risk_mc)
135
+ return cvar
136
+
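+ # Illustrative note (not in the original file): the Monte Carlo CVaR above first takes
+ # the weighted alpha-quantile of the sampled costs as the value-at-risk (VaR), then
+ # estimates CVaR = VaR + E[max(C - VaR, 0)] / (1 - alpha), falling back to the maximum
+ # sample when alpha is within eps of 1. A usage sketch with uniform weights:
+ #   estimator = CVaREstimator(eps=1e-4)
+ #   cvar = estimator(risk_level, cost, torch.ones_like(cost))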
137
+
138
+ def get_risk_estimator(estimator_params: dict) -> AbstractMonteCarloRiskEstimator:
139
+ """Function that returns the Monte Carlo risk estimator hat matches the given parameters.
140
+ Tries to give a comprehensive feedback if the parameters are not recognized and raise an error.
141
+
142
+ Args:
143
+ estimator_params: dictionary defining the risk estimator; it should take one of the following forms (with parameter values as desired):
144
+ {"type": "entropic", "eps": 1e-4},
145
+ {"type": "cvar", "eps": 1e-4}
146
+
147
+ Raises:
148
+ RuntimeError: If the given parameter dictionary does not match one of the expected formats, raise a comprehensive error.
149
+
150
+ Returns:
151
+ A risk estimator matching the given parameters.
152
+ """
153
+ known_types = ["entropic", "cvar"]
154
+ try:
155
+ if estimator_params["type"].lower() == "entropic":
156
+ expected_params = inspect.getfullargspec(EntropicRiskEstimator)[0][1:]
157
+ return EntropicRiskEstimator(estimator_params["eps"])
158
+ elif estimator_params["type"].lower() == "cvar":
159
+ expected_params = inspect.getfullargspec(CVaREstimator)[0][1:]
160
+ return CVaREstimator(estimator_params["eps"])
161
+ else:
162
+ raise RuntimeError(
163
+ f"Risk estimator '{estimator_params['type']}' is unknown. It should be one of {known_types}."
164
+ )
165
+ except KeyError:
166
+ if "type" in estimator_params:
167
+ raise RuntimeError(
168
+ f"""The estimator '{estimator_params['type']}' is known but the given parameters
169
+ {estimator_params} do not match the expected parameters {expected_params}."""
170
+ )
171
+ else:
172
+ raise RuntimeError(
173
+ f"""The given estimator parameters {estimator_params} do not define the estimator
174
+ type in the field 'type'. Please add a field 'type' and set it to one of the
175
+ handled types: {known_types}."""
176
+ )
177
+
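+ # Usage sketch (illustrative): both estimator types are built from a small config dict, e.g.
+ #   get_risk_estimator({"type": "cvar", "eps": 1e-4})
+ #   get_risk_estimator({"type": "entropic", "eps": 1e-4})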
178
+
179
+ class AbstractRiskLevelSampler(ABC):
180
+ """Abstract class for a risk-level sampler for training and evaluating risk-biased predictors"""
181
+
182
+ @abstractmethod
183
+ def sample(self, batch_size: int, device: torch.device) -> Tensor:
184
+ """Returns a tensor of size batch_size with sampled risk-level values
185
+
186
+ Args:
187
+ batch_size: number of elements in the out tensor
188
+ device: device of the output tensor
189
+
190
+ Returns:
191
+ A tensor of shape (batch_size,) filled with sampled risk values
192
+ """
193
+
194
+ @abstractmethod
195
+ def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
196
+ """Returns a tensor of size batch_size with high values of risk.
197
+
198
+ Args:
199
+ batch_size: number of elements in the out tensor
200
+ device: device of the output tensor
201
+
202
+ Returns:
203
+ A tensor of shape (batch_size,) filled with the highest possible risk-level
204
+ """
205
+
206
+
207
+ class UniformRiskLevelSampler(AbstractRiskLevelSampler):
208
+ """Risk-level sampler with a uniform distribution
209
+
210
+ Args:
211
+ min: minimum risk-level
212
+ max: maximum risk-level
213
+ """
214
+
215
+ def __init__(self, min: float, max: float) -> None:
216
+ self.min = min
217
+ self.max = max
218
+
219
+ def sample(self, batch_size: int, device: torch.device) -> Tensor:
220
+ return torch.rand(batch_size, device=device) * (self.max - self.min) + self.min
221
+
222
+ def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
223
+ return torch.ones(batch_size, device=device) * self.max
224
+
225
+
226
+ class NormalRiskLevelSampler(AbstractRiskLevelSampler):
227
+ """Risk-level sampler with a normal distribution
228
+
229
+ Args:
230
+ mean: average risk-level
231
+ sigma: standard deviation of the sampler
232
+ """
233
+
234
+ def __init__(self, mean: float, sigma: float) -> None:
235
+ self.mean = mean
236
+ self.sigma = sigma
237
+
238
+ def sample(self, batch_size: int, device: torch.device) -> Tensor:
239
+ return torch.randn(batch_size, device=device) * self.sigma + self.mean
240
+
241
+ def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
242
+ return torch.ones(batch_size, device=device) * self.sigma * 3
243
+
244
+
245
+ class BernoulliRiskLevelSampler(AbstractRiskLevelSampler):
246
+ """Risk-level sampler with a scaled Bernoulli distribution
247
+
248
+ Args:
249
+ min: minimum risk-level
250
+ max: maximum risk-level
251
+ p: Bernoulli parameter
252
+ """
253
+
254
+ def __init__(self, min: float, max: float, p: float) -> None:
255
+ self.min = min
256
+ self.max = max
257
+ self.p = p
258
+
259
+ def sample(self, batch_size: int, device: torch.device) -> Tensor:
260
+ return (
261
+ torch.bernoulli(torch.ones(batch_size, device=device) * self.p)
262
+ * (self.max - self.min)
263
+ + self.min
264
+ )
265
+
266
+ def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
267
+ return torch.ones(batch_size, device=device) * self.max
268
+
269
+
270
+ class BetaRiskLevelSampler(AbstractRiskLevelSampler):
271
+ """Risk-level sampler with a scaled Beta distribution
272
+
273
+ Distribution properties:
274
+ mean = alpha*(max-min)/(alpha + beta) + min
275
+ mode = (alpha-1)*(max-min)/(alpha + beta - 2) + min
276
+ variance = alpha*beta*(max-min)**2/((alpha+beta)**2*(alpha+beta+1))
277
+
278
+ Args:
279
+ min: minimum risk-level
280
+ max: maximum risk-level
281
+ alpha: First distribution parameter
282
+ beta: Second distribution parameter
283
+ """
284
+
285
+ def __init__(self, min: float, max: float, alpha: float, beta: float) -> None:
286
+ self.min = min
287
+ self.max = max
288
+ self._distribution = torch.distributions.Beta(
289
+ torch.tensor([alpha], dtype=torch.float32),
290
+ torch.tensor([beta], dtype=torch.float32),
291
+ )
292
+
293
+ @property
294
+ def alpha(self):
295
+ return self._distribution.concentration1.item()
296
+
297
+ @alpha.setter
298
+ def alpha(self, alpha: float):
299
+ self._distribution = torch.distributions.Beta(
300
+ torch.tensor([alpha], dtype=torch.float32),
301
+ torch.tensor([self.beta], dtype=torch.float32),
302
+ )
303
+
304
+ @property
305
+ def beta(self):
306
+ return self._distribution.concentration0.item()
307
+
308
+ @beta.setter
309
+ def beta(self, beta: float):
310
+ self._distribution = torch.distributions.Beta(
311
+ torch.tensor([self.alpha], dtype=torch.float32),
312
+ torch.tensor([beta], dtype=torch.float32),
313
+ )
314
+
315
+ def sample(self, batch_size: int, device: torch.device) -> Tensor:
316
+ return (
317
+ self._distribution.sample((batch_size,)).to(device) * (self.max - self.min)
318
+ + self.min
319
+ ).view(batch_size)
320
+
321
+ def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
322
+ return torch.ones(batch_size, device=device) * self.max
323
+
324
+
325
+ class Chi2RiskLevelSampler(AbstractRiskLevelSampler):
326
+ """Risk-level sampler with a scaled chi2 distribution
327
+
328
+ Distribution properties:
329
+ mean = k*scale + min
330
+ mode = max(k-2, 0)*scale + min
331
+ variance = 2*k*scale**2
332
+
333
+ Args:
334
+ min: minimum risk-level
335
+ scale: scaling factor for the risk-level
336
+ k: Chi2 parameter: degrees of freedom of the distribution
337
+ """
338
+
339
+ def __init__(self, min: float, scale: float, k: int) -> None:
340
+ self.min = min
341
+ self.scale = scale
342
+ self._distribution = torch.distributions.Chi2(
343
+ torch.tensor([k], dtype=torch.float32)
344
+ )
345
+
346
+ @property
347
+ def k(self):
348
+ return self._distribution.df.item()
349
+
350
+ @k.setter
351
+ def k(self, k: int):
352
+ self._distribution = torch.distributions.Chi2(
353
+ torch.tensor([k], dtype=torch.float32)
354
+ )
355
+
356
+ def sample(self, batch_size: int, device: torch.device) -> Tensor:
357
+ return (
358
+ self._distribution.sample((batch_size,)).to(device) * self.scale + self.min
359
+ ).view(batch_size)
360
+
361
+ def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
362
+ std = self.scale * math.sqrt(2 * self.k)
363
+ return torch.ones(batch_size, device=device) * std * 3
364
+
365
+
366
+ class LogNormalRiskLevelSampler(AbstractRiskLevelSampler):
367
+ """Risk-level sampler with a scaled Beta distribution
368
+
369
+ Distribution properties:
370
+ mean = exp(mu + sigma**2/2)*scale + min
371
+ mode = exp(mu - sigma**2)*scale + min
372
+ variance = (exp(sigma**2)-1)*exp(2*mu+sigma**2)*scale**2
373
+
374
+ Args:
375
+ min: minimum risk-level
376
+ scale: scaling factor for the risk-level
377
+ mu: First distribution parameter
378
+ sigma: Second distribution parameter (standard deviation of the underlying normal)
379
+ """
380
+
381
+ def __init__(self, min: float, scale: float, mu: float, sigma: float) -> None:
382
+ self.min = min
383
+ self.scale = scale
384
+ self._distribution = torch.distributions.LogNormal(
385
+ torch.tensor([mu], dtype=torch.float32),
386
+ torch.tensor([sigma], dtype=torch.float32),
387
+ )
388
+
389
+ @property
390
+ def mu(self):
391
+ return self._distribution.loc.item()
392
+
393
+ @mu.setter
394
+ def mu(self, mu: float):
395
+ self._distribution = torch.distributions.LogNormal(
396
+ torch.tensor([mu], dtype=torch.float32),
397
+ torch.tensor([self.sigma], dtype=torch.float32),
398
+ )
399
+
400
+ @property
401
+ def sigma(self) -> float:
402
+ return self._distribution.scale.item()
403
+
404
+ @sigma.setter
405
+ def sigma(self, sigma: float):
406
+ self._distribution = torch.distributions.LogNormal(
407
+ torch.tensor([self.mu], dtype=torch.float32),
408
+ torch.tensor([sigma], dtype=torch.float32),
409
+ )
410
+
411
+ def sample(self, batch_size: int, device: torch.device) -> Tensor:
412
+ return (
413
+ self._distribution.sample((batch_size,)).to(device) * self.scale + self.min
414
+ ).view(batch_size)
415
+
416
+ def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
417
+ std = (
418
+ math.sqrt(math.exp(self.sigma**2) - 1)
419
+ * math.exp(self.mu + self.sigma**2 / 2)
420
+ * self.scale
421
+ )
422
+ return torch.ones(batch_size, device=device) * 3 * std
423
+
424
+
425
+ class LogUniformRiskLevelSampler(AbstractRiskLevelSampler):
426
+ """Risk-level sampler with a reversed log-uniform distribution (increasing density function). Between min and max.
427
+
428
+ Distribution properties:
429
+ mean = (max - min)/ln((max+1)/(min+1)) - 1/scale
430
+ mode = None
431
+ variance = (((max+1)^2 - (min+1)^2)/(2*ln((max+1)/(min+1))) - ((max - min)/ln((max+1)/(min+1)))^2)
432
+
433
+ Args:
434
+ min: minimum risk-level
435
+ max: maximum risk-level
436
+ scale: scale to apply to the sampling before applying exponential,
437
+ the output is rescaled back to fit in bounds [min, max] (the higher the scale the less uniform the distribution)
438
+ """
439
+
440
+ def __init__(self, min: float, max: float, scale: float) -> None:
441
+ assert min >= 0
442
+ assert max > min
443
+ assert scale > 0
444
+ self.min = min
445
+ self.max = max
446
+ self.scale = scale
447
+
448
+ def sample(self, batch_size: int, device: torch.device) -> Tensor:
449
+ scale = self.scale / (self.max - self.min)
450
+ max = self.max * scale
451
+ min = self.min * scale
452
+ return (
453
+ max
454
+ - (
455
+ (
456
+ torch.rand(batch_size, device=device)
457
+ * (math.log(max + 1) - math.log(min + 1))
458
+ + math.log(min + 1)
459
+ ).exp()
460
+ - 1
461
+ )
462
+ + min
463
+ ) / scale
464
+
465
+ def get_highest_risk(self, batch_size: int, device: torch.device) -> Tensor:
466
+ return torch.ones(batch_size, device=device) * self.max
467
+
468
+
469
+ def get_risk_level_sampler(distribution_params: dict) -> AbstractRiskLevelSampler:
470
+ """Function that returns the risk level sampler that matches the given parameters.
471
+ Tries to give comprehensive feedback and raises an error if the parameters are not recognized.
472
+
473
+ Args:
474
+ distribution_params: dictionary defining the risk-level distribution; it should take one of the following forms (with parameter values as desired):
475
+ {"type": "uniform", "min": 0, "max": 1},
476
+ {"type": "normal", "mean": 0, "sigma": 1},
477
+ {"type": "bernoulli", "p": 0.5, "min": 0, "max": 1},
478
+ {"type": "beta", "alpha": 2, "beta": 5, "min": 0, "max": 1},
479
+ {"type": "chi2", "k": 3, "min": 0, "scale": 1},
480
+ {"type": "log-normal", "mu": 0, "sigma": 1, "min": 0, "scale": 1}
481
+ {"type": "log-uniform", "min": 0, "max": 1, "scale": 1}
482
+
483
+ Raises:
484
+ RuntimeError: If the given parameter dictionary does not match one of the expected formats, raise a comprehensive error.
485
+
486
+ Returns:
487
+ A risk level sampler matching the given parameters.
488
+ """
489
+ known_types = [
490
+ "uniform",
491
+ "normal",
492
+ "bernoulli",
493
+ "beta",
494
+ "chi2",
495
+ "log-normal",
496
+ "log-uniform",
497
+ ]
498
+ try:
499
+ if distribution_params["type"].lower() == "uniform":
500
+ expected_params = inspect.getfullargspec(UniformRiskLevelSampler)[0][1:]
501
+ return UniformRiskLevelSampler(
502
+ distribution_params["min"], distribution_params["max"]
503
+ )
504
+ elif distribution_params["type"].lower() == "normal":
505
+ expected_params = inspect.getfullargspec(NormalRiskLevelSampler)[0][1:]
506
+ return NormalRiskLevelSampler(
507
+ distribution_params["mean"], distribution_params["sigma"]
508
+ )
509
+ elif distribution_params["type"].lower() == "bernoulli":
510
+ expected_params = inspect.getfullargspec(BernoulliRiskLevelSampler)[0][1:]
511
+ return BernoulliRiskLevelSampler(
512
+ distribution_params["min"],
513
+ distribution_params["max"],
514
+ distribution_params["p"],
515
+ )
516
+ elif distribution_params["type"].lower() == "beta":
517
+ expected_params = inspect.getfullargspec(BetaRiskLevelSampler)[0][1:]
518
+ return BetaRiskLevelSampler(
519
+ distribution_params["min"],
520
+ distribution_params["max"],
521
+ distribution_params["alpha"],
522
+ distribution_params["beta"],
523
+ )
524
+ elif distribution_params["type"].lower() == "chi2":
525
+ expected_params = inspect.getfullargspec(Chi2RiskLevelSampler)[0][1:]
526
+ return Chi2RiskLevelSampler(
527
+ distribution_params["min"],
528
+ distribution_params["scale"],
529
+ distribution_params["k"],
530
+ )
531
+ elif distribution_params["type"].lower() == "log-normal":
532
+ expected_params = inspect.getfullargspec(LogNormalRiskLevelSampler)[0][1:]
533
+ return LogNormalRiskLevelSampler(
534
+ distribution_params["min"],
535
+ distribution_params["scale"],
536
+ distribution_params["mu"],
537
+ distribution_params["sigma"],
538
+ )
539
+ elif distribution_params["type"].lower() == "log-uniform":
540
+ expected_params = inspect.getfullargspec(LogUniformRiskLevelSampler)[0][1:]
541
+ return LogUniformRiskLevelSampler(
542
+ distribution_params["min"],
543
+ distribution_params["max"],
544
+ distribution_params["scale"],
545
+ )
546
+ else:
547
+ raise RuntimeError(
548
+ f"Distribution {distribution_params['type']} is unknown. It should be one of {known_types}."
549
+ )
550
+ except KeyError:
551
+ if "type" in distribution_params:
552
+ raise RuntimeError(
553
+ f"The distribution '{distribution_params['type']}' is known but the given parameters {distribution_params} do not match the expected parameters {expected_params}."
554
+ )
555
+ else:
556
+ raise RuntimeError(
557
+ f"The given distribution parameters {distribution_params} do not define the distribution type in the field 'type'. Please add a field 'type' and set it to one of the handeled types: {known_types}."
558
+ )
559
+
560
+
561
+ if __name__ == "__main__":
562
+ import matplotlib.pyplot as plt
563
+
564
+ sampler = get_risk_level_sampler(
565
+ {"type": "log-uniform", "min": 0, "max": 1, "scale": 10}
566
+ )
567
+ # sampler = get_risk_level_sampler({"type": "normal", "mean": 0, "sigma": 1})
568
+ a = sampler.sample(10000, "cpu").detach().numpy()
569
+ _ = plt.hist(a, bins="auto") # arguments are passed to np.histogram
570
+ plt.title("Histogram with 'auto' bins")
571
+ plt.show()
risk_biased/utils/torch_utils.py ADDED
@@ -0,0 +1,66 @@
1
+ import warnings
2
+
3
+ import torch
4
+ from torch import Tensor
5
+
6
+
7
+ @torch.jit.script
8
+ def torch_linspace(start: Tensor, stop: Tensor, num: int) -> torch.Tensor:
9
+ """
10
+ Copy-pasted from https://github.com/pytorch/pytorch/issues/61292
11
+ Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
12
+ Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
13
+ """
14
+ # create a tensor of 'num' steps from 0 to 1
15
+ steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
16
+
17
+ # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
18
+ # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
19
+ # "cannot statically infer the expected size of a list in this contex", hence the code below
20
+ for i in range(start.ndim):
21
+ steps = steps.unsqueeze(-1)
22
+
23
+ # the output starts at 'start' and increments until 'stop' in each dimension
24
+ out = start[None] + steps * (stop - start)[None]
25
+
26
+ return out
27
+
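+ # Illustrative example (not in the original file): with start = torch.zeros(2) and
+ # stop = torch.tensor([1.0, 2.0]), torch_linspace(start, stop, num=3) returns a (3, 2)
+ # tensor [[0.0, 0.0], [0.5, 1.0], [1.0, 2.0]], i.e. numpy.linspace broadcast over tensors.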
28
+
29
+ def load_weights(
30
+ model: torch.nn.Module, checkpoint: dict, strict=True
31
+ ) -> torch.nn.Module:
32
+ """This function is used instead of the one provided by pytorch lightning
33
+ because for unexplained reasons, the pytorch lightning load function did
34
+ not behave as intended: loading several times from the same checkpoint
35
+ resulted in different loaded weight values...
36
+
37
+ Args:
38
+ model: a model in which new weights should be set
39
+ checkpoint: a loaded pytorch checkpoint (probably resulting from torch.load(filename))
40
+ strict: Default to True, wether to fail if
41
+
42
+ """
43
+ if not strict:
44
+ model_dict = model.state_dict()
45
+ pretrained_dict = {
46
+ k: v for k, v in checkpoint["state_dict"].items() if k in model_dict
47
+ }
48
+ diff1 = checkpoint["state_dict"].keys() - model_dict.keys()
49
+ if diff1:
50
+ warnings.warn(
51
+ f"Found keys {diff1} in checkpoint without any match in the model, ignoring corresponding values."
52
+ )
53
+ diff2 = model_dict.keys() - checkpoint["state_dict"].keys()
54
+ if diff2:
55
+ warnings.warn(
56
+ f"Missing keys {diff2} from the checkpoint, the corresponding weights will keep their initial values."
57
+ )
58
+ pretrained_dict = {
59
+ k: v for k, v in checkpoint["state_dict"].items() if k in model_dict
60
+ }
61
+ model_dict.update(pretrained_dict)
62
+ else:
63
+ model_dict = checkpoint["state_dict"]
64
+
65
+ model.load_state_dict(model_dict, strict=strict)
66
+ return model
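+
+ # Usage sketch (illustrative, the checkpoint file name is hypothetical):
+ #   checkpoint = torch.load("checkpoint.ckpt", map_location="cpu")
+ #   model = load_weights(model, checkpoint, strict=False)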
risk_biased/utils/waymo_dataloader.py ADDED
@@ -0,0 +1,490 @@
1
+ from typing import Tuple, List
3
+ from einops import rearrange, repeat
4
+ from torch.utils.data import Dataset
5
+ from torch.utils.data import DataLoader
6
+ import torch
7
+ from torch import Tensor
8
+ import numpy as np
9
+ import pickle
10
+ import os
11
+
12
+ from mmcv import Config
13
+
14
+
15
+ class WaymoDataset(Dataset):
16
+ """
17
+ Dataset loader for custom preprocessed files of Waymo data.
18
+ Args:
19
+ cfg: configuration object defining dataset paths and preprocessing settings
20
+ split: dataset split to load ("train", "val", "test", or "sample")
+ input_angle: whether the heading angle is kept in the agent state
21
+ """
22
+
23
+ def __init__(self, cfg: Config, split: str, input_angle: bool = True):
24
+ super(WaymoDataset, self).__init__()
25
+ self.p_exchange_two_first = 1
26
+ if "val" in split.lower():
27
+ path = cfg.val_dataset_path
28
+ elif "test" in split.lower():
29
+ path = cfg.test_dataset_path
30
+ elif "sample" in split.lower():
31
+ path = cfg.sample_dataset_path
32
+ else:
33
+ path = cfg.train_dataset_path
34
+ self.p_exchange_two_first = cfg.p_exchange_two_first
35
+
36
+ self.file_list = [
37
+ os.path.join(path, name)
38
+ for name in os.listdir(path)
39
+ if os.path.isfile(os.path.join(path, name))
40
+ ]
41
+ self.normalize = cfg.normalize_angle
42
+ # self.load_dataset(path, 16)
43
+ # self.idx_list = list(self.dataset.keys())
44
+ self.input_angle = input_angle
45
+ self.hist_len = cfg.num_steps
46
+ self.fut_len = cfg.num_steps_future
47
+ self.time_len = self.hist_len + self.fut_len
48
+ self.min_num_obs = cfg.min_num_observation
49
+ self.max_size_lane = cfg.max_size_lane
50
+ self.random_rotation = cfg.random_rotation
51
+ self.random_translation = cfg.random_translation
52
+ self.angle_std = cfg.angle_std
53
+ self.translation_distance_std = cfg.translation_distance_std
54
+ self.max_num_agents = cfg.max_num_agents
55
+ self.max_num_objects = cfg.max_num_objects
56
+ self.state_dim = cfg.state_dim
57
+ self.map_state_dim = cfg.map_state_dim
58
+ self.dt = cfg.dt
59
+
60
+ if "val" in os.path.basename(path).lower():
61
+ self.dataset_size_limit = cfg.val_dataset_size_limit
62
+ else:
63
+ self.dataset_size_limit = cfg.train_dataset_size_limit
64
+
65
+ def __len__(self):
66
+ if self.dataset_size_limit is not None:
67
+ return min(len(self.file_list), self.dataset_size_limit)
68
+ else:
69
+ return len(self.file_list)
70
+
71
+ def __getitem__(
72
+ self, idx: int
73
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
74
+ """
75
+ Get the item at index idx in the dataset. Normalize the scene and output absolute angle and position.
76
+ Returns:
77
+ trajectories, mask, mask_loss, lanes, mask_lanes, lane_states, angle, mean_position, idx
78
+ """
79
+ selected_file = self.file_list[idx]
80
+ with open(selected_file, "rb") as handle:
81
+ dataset = pickle.load(handle)
82
+ rel_state_all = dataset["traj"]
83
+ mask_all = dataset["mask_traj"]
84
+ mask_loss = dataset["mask_to_predict"]
85
+ rel_lane_all = dataset["lanes"]
86
+ mask_lane_all = dataset["mask_lanes"]
87
+ mean_pos = dataset["mean_pos"]
88
+ assert (
89
+ (
90
+ rel_state_all[self.hist_len + 5 :, :, :2][mask_all[self.hist_len + 5 :]]
91
+ != 0
92
+ )
93
+ .any(-1)
94
+ .all()
95
+ )
96
+ assert (
97
+ (
98
+ rel_state_all[self.hist_len + 5 :, :, :2][
99
+ mask_loss[self.hist_len + 5 :]
100
+ ]
101
+ != 0
102
+ )
103
+ .any(-1)
104
+ .all()
105
+ )
106
+ if "lane_states" in dataset.keys():
107
+ lane_states = dataset["lane_states"]
108
+ else:
109
+ lane_states = None
110
+ if np.random.rand() > self.p_exchange_two_first:
111
+ rel_state_all[:, [0, 1]] = rel_state_all[:, [1, 0]]
112
+ mask_all[:, [0, 1]] = mask_all[:, [1, 0]]
113
+ mask_loss[:, [0, 1]] = mask_loss[:, [1, 0]]
114
+ assert (
115
+ (
116
+ rel_state_all[self.hist_len + 5 :, :, :2][mask_all[self.hist_len + 5 :]]
117
+ != 0
118
+ )
119
+ .any(-1)
120
+ .all()
121
+ )
122
+ assert (
123
+ (
124
+ rel_state_all[self.hist_len + 5 :, :, :2][
125
+ mask_loss[self.hist_len + 5 :]
126
+ ]
127
+ != 0
128
+ )
129
+ .any(-1)
130
+ .all()
131
+ )
132
+ if self.normalize:
133
+ angle = rel_state_all[self.hist_len - 1, 1, 2]
134
+
135
+ if self.random_rotation:
136
+ if self.normalize:
137
+ angle += np.random.normal(0, self.angle_std)
138
+ else:
139
+ angle += np.random.uniform(-np.pi, np.pi)
140
+ if self.random_translation:
141
+ distance = (
142
+ np.random.normal([0, 0], self.translation_distance_std, 2)
143
+ * mask_all[self.hist_len - 1 : self.hist_len, :, None]
144
+ - rel_state_all[self.hist_len - 1 : self.hist_len, 1:2, :2]
145
+ )
146
+ else:
147
+ distance = -rel_state_all[self.hist_len - 1 : self.hist_len, 1:2, :2]
148
+
149
+ rel_state_all[:, :, :2] += distance
150
+ rel_lane_all[:, :, :2] += distance
151
+ mean_pos += distance[0, 0, :]
152
+ rel_state_all = self.scene_rotation(rel_state_all, -angle)
153
+ rel_lane_all = self.scene_rotation(rel_lane_all, -angle)
154
+
155
+ else:
156
+ if self.random_translation:
157
+ distance = np.random.normal([0, 0], self.translation_distance_std, 2)
158
+ rel_state_all = (
159
+ rel_state_all
160
+ + mask_all[self.hist_len - 1 : self.hist_len, :, None] * distance
161
+ )
162
+ rel_lane_all = (
163
+ rel_lane_all
164
+ + mask_all[self.hist_len - 1 : self.hist_len, :, None] * distance
165
+ )
166
+
167
+ if self.random_rotation:
168
+ angle = np.random.uniform(0, 2 * np.pi)
169
+ rel_state_all = self.scene_rotation(rel_state_all, angle)
170
+ rel_lane_all = self.scene_rotation(rel_lane_all, angle)
171
+ else:
172
+ angle = 0
173
+ return (
174
+ rel_state_all,
175
+ mask_all,
176
+ mask_loss,
177
+ rel_lane_all,
178
+ mask_lane_all,
179
+ lane_states,
180
+ angle,
181
+ mean_pos,
182
+ idx,
183
+ )
184
+
185
+ @staticmethod
186
+ def scene_rotation(coor: np.ndarray, angle: float) -> np.ndarray:
187
+ """
188
+ Rotate all the coordinates with the same angle
189
+ Args:
190
+ coor: array of x, y coordinates
191
+ angle: radiants to rotate the coordinates by
192
+ Returns:
193
+ coor_rotated
194
+ """
195
+ rot_matrix = np.zeros((2, 2))
196
+ c = np.cos(angle)
197
+ s = np.sin(angle)
198
+ rot_matrix[0, 0] = c
199
+ rot_matrix[0, 1] = -s
200
+ rot_matrix[1, 0] = s
201
+ rot_matrix[1, 1] = c
202
+ coor[..., :2] = np.matmul(
203
+ rot_matrix, np.expand_dims(coor[..., :2], axis=-1)
204
+ ).squeeze(-1)
205
+ if coor.shape[-1] > 2:
206
+ coor[..., 2] += angle
207
+ if coor.shape[-1] >= 5:
208
+ coor[..., 3:5] = np.matmul(
209
+ rot_matrix, np.expand_dims(coor[..., 3:5], axis=-1)
210
+ ).squeeze(-1)
211
+ return coor
212
+
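+ # Note on scene_rotation (added for clarity): it applies the standard 2D rotation
+ #   [x'; y'] = [[cos(a), -sin(a)], [sin(a), cos(a)]] @ [x; y]
+ # to positions (and to the velocity columns 3:5 when present) and shifts heading angles by `a`.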
213
+ def fill_past(self, past, mask_past):
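+ """Fill gaps in past trajectories: missing velocities are carried forward from the last
+ observed step and missing positions are extrapolated from the previous position with that velocity and dt."""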
214
+ current_velocity = past[..., 0, 3:5]
215
+ for t in range(1, past.shape[-2]):
216
+ current_velocity = torch.where(
217
+ mask_past[..., t, None], past[..., t, 3:5], current_velocity
218
+ )
219
+ past[..., t, 3:5] = current_velocity
220
+ predicted_position = past[..., t - 1, :2] + current_velocity * self.dt
221
+ past[..., t, :2] = torch.where(
222
+ mask_past[..., t, None], past[..., t, :2], predicted_position
223
+ )
224
+ return past
225
+
226
+ def collate_fn(
227
+ self, samples: List
228
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
229
+ """
230
+ Assemble trajectories into batches with 0-padding.
231
+ Args:
232
+ samples: list of sampled trajectories (list of outputs of __getitem__)
233
+ Returns:
234
+ (starred dimensions have different values from one batch to the next but the ones with the same name are consistent within the batch)
235
+ batch : ((batch_size, num_agents*, num_steps, state_dim), # past trajectories of all agents in the scene
236
+ (batch_size, num_agents*, num_steps), # mask past False where past trajectories are padding data
237
+ (batch_size, num_agents*, num_steps_future, state_dim), # future trajectory
238
+ (batch_size, num_agents*, num_steps_future), # mask future False where future trajectories are padding data
239
+ (batch_size, num_agents*, num_steps_future), # mask loss False where future trajectories are not to be predicted
240
+ (batch_size, num_objects*, object_seq_len*, map_state_dim),# map object sequences in the scene
241
+ (batch_size, num_objects*, object_seq_len*), # mask map False where map objects are padding data
242
+ (batch_size, num_agents*, state_dim), # position offset of all agents relative to ego at present time
243
+ (batch_size, num_steps, state_dim), # ego past trajectory
244
+ (batch_size, num_steps_future, state_dim)) # ego future trajectory
245
+ """
246
+ max_n_vehicle = 50
247
+ max_n_lanes = 0
248
+ for (
249
+ coor,
250
+ mask,
251
+ mask_loss,
252
+ lanes,
253
+ mask_lanes,
254
+ lane_states,
255
+ mean_angle,
256
+ mean_pos,
257
+ idx,
258
+ ) in samples:
259
+ # time_len_coor = self._count_last_obs(coor, hist_len)
260
+ # num_vehicle = np.sum(time_len_coor > self.min_num_obs)
261
+ num_vehicle = coor.shape[1]
262
+ num_lanes = lanes.shape[1]
263
+ max_n_vehicle = max(num_vehicle, max_n_vehicle)
264
+ max_n_lanes = max(num_lanes, max_n_lanes)
265
+ if max_n_vehicle <= 0:
266
+ raise RuntimeError
267
+ data_batch = np.zeros(
268
+ [self.time_len, len(samples), max_n_vehicle, self.state_dim]
269
+ )
270
+ mask_batch = np.zeros([self.time_len, len(samples), max_n_vehicle])
271
+ mask_loss_batch = np.zeros([self.time_len, len(samples), max_n_vehicle])
272
+ lane_batch = np.zeros(
273
+ [self.max_size_lane, len(samples), max_n_lanes, self.map_state_dim]
274
+ )
275
+ mask_lane_batch = np.zeros([self.max_size_lane, len(samples), max_n_lanes])
276
+ mean_angle_batch = np.zeros([len(samples)])
277
+ mean_pos_batch = np.zeros([len(samples), 2])
278
+ tag_list = np.zeros([len(samples)])
279
+ idx_list = [0 for _ in range(len(samples))]
280
+
281
+ for sample_ind, (
282
+ coor,
283
+ mask,
284
+ mask_loss,
285
+ lanes,
286
+ mask_lanes,
287
+ lane_states,
288
+ mean_angle,
289
+ mean_pos,
290
+ idx,
291
+ ) in enumerate(samples):
292
+ data_batch[:, sample_ind, : coor.shape[1], :] = coor[: self.time_len, :, :]
293
+ mask_batch[:, sample_ind, : mask.shape[1]] = mask[: self.time_len, :]
294
+ mask_loss_batch[:, sample_ind, : mask.shape[1]] = mask_loss[
295
+ : self.time_len, :
296
+ ]
297
+ lane_batch[: lanes.shape[0], sample_ind, : lanes.shape[1], :2] = lanes
298
+ if lane_states is not None:
299
+ lane_states = repeat(
300
+ lane_states[:, : self.hist_len],
301
+ "objects time features -> one objects (time features)",
302
+ one=1,
303
+ )
304
+ lane_batch[
305
+ : lanes.shape[0], sample_ind, : lanes.shape[1], 2:
306
+ ] = lane_states
307
+ mask_lane_batch[
308
+ : mask_lanes.shape[0], sample_ind, : mask_lanes.shape[1]
309
+ ] = mask_lanes
310
+ mean_angle_batch[sample_ind] = mean_angle
311
+ mean_pos_batch[sample_ind, :] = mean_pos
312
+ # tag_list[sample_ind] = self.dataset[idx]["tag"]
313
+ idx_list[sample_ind] = idx
314
+
315
+ data_batch = torch.from_numpy(data_batch.astype("float32"))
316
+ mask_batch = torch.from_numpy(mask_batch.astype("bool"))
317
+ lane_batch = torch.from_numpy(lane_batch.astype("float32"))
318
+ mask_lane_batch = torch.from_numpy(mask_lane_batch.astype("bool"))
319
+ mean_pos_batch = torch.from_numpy(mean_pos_batch.astype("float32"))
320
+ mask_loss_batch = torch.from_numpy(mask_loss_batch.astype("bool"))
321
+
322
+ data_batch = rearrange(
323
+ data_batch, "time batch agents features -> batch agents time features"
324
+ )
325
+ mask_batch = rearrange(mask_batch, "time batch agents -> batch agents time")
326
+ mask_loss_batch = rearrange(
327
+ mask_loss_batch, "time batch agents -> batch agents time"
328
+ )
329
+ lane_batch = rearrange(
330
+ lane_batch,
331
+ "object_seq_len batch objects features-> batch objects object_seq_len features",
332
+ )
333
+ mask_lane_batch = rearrange(
334
+ mask_lane_batch,
335
+ "object_seq_len batch objects -> batch objects object_seq_len",
336
+ )
337
+
338
+ # The two first agents are the ones interacting, others are sorted by distance from the first agent
339
+ # Objects are also sorted by distance from the first agent
340
+ # Therefore, the limits in number, max_num_agents and max_num_objects, can be seen as adaptive distance limits.
341
+
342
+ if not self.input_angle:
343
+ data_batch = torch.cat((data_batch[..., :2], data_batch[..., 3:]), dim=-1)
344
+ traj_past = data_batch[:, : self.max_num_agents, : self.hist_len, :]
345
+ mask_past = mask_batch[:, : self.max_num_agents, : self.hist_len]
346
+ traj_fut = data_batch[
347
+ :, : self.max_num_agents, self.hist_len : self.hist_len + self.fut_len, :
348
+ ]
349
+ mask_fut = mask_batch[
350
+ :, : self.max_num_agents, self.hist_len : self.hist_len + self.fut_len
351
+ ]
352
+ ego_past = data_batch[:, 0:1, : self.hist_len, :]
353
+ ego_fut = data_batch[:, 0:1, self.hist_len : self.hist_len + self.fut_len, :]
354
+
355
+ lane_batch = lane_batch[:, : self.max_num_objects]
356
+ mask_lane_batch = mask_lane_batch[:, : self.max_num_objects]
357
+
358
+ # Define what to predict (could be from Waymo's label of what to predict or the other agent that interacts with the ego...)
359
+ mask_loss_batch = torch.logical_and(
360
+ mask_loss_batch[
361
+ :, : self.max_num_agents, self.hist_len : self.hist_len + self.fut_len
362
+ ],
363
+ mask_past.any(-1, keepdim=True),
364
+ )
365
+ # Remove all other agents so the model should only predict the first one
366
+ mask_loss_batch[:, 0] = False
367
+ mask_loss_batch[:, 2:] = False
368
+
369
+ # Normalize...
370
+ # traj_past = self.fill_past(traj_past, mask_past)
371
+ dynamic_state_size = 5 if self.input_angle else 4
372
+ offset_batch = traj_past[..., -1, :dynamic_state_size].clone()
373
+ traj_past[..., :dynamic_state_size] = traj_past[
374
+ ..., :dynamic_state_size
375
+ ] - offset_batch.unsqueeze(-2)
376
+ traj_fut[..., :dynamic_state_size] = traj_fut[
377
+ ..., :dynamic_state_size
378
+ ] - offset_batch.unsqueeze(-2)
379
+
380
+ return (
381
+ traj_past,
382
+ mask_past,
383
+ traj_fut,
384
+ mask_fut,
385
+ mask_loss_batch,
386
+ lane_batch,
387
+ mask_lane_batch,
388
+ offset_batch,
389
+ ego_past,
390
+ ego_fut,
391
+ )
392
+
393
+
394
+ class WaymoDataloaders:
395
+ def __init__(self, cfg: Config) -> None:
396
+ self.cfg = cfg
397
+
398
+ def sample_dataloader(self) -> DataLoader:
399
+ """Setup and return sample DataLoader
400
+
401
+ Returns:
402
+ DataLoader: sample DataLoader
403
+ """
404
+ dataset = WaymoDataset(self.cfg, "sample")
405
+ sample_loader = DataLoader(
406
+ dataset=dataset,
407
+ batch_size=self.cfg.batch_size,
408
+ shuffle=False,
409
+ num_workers=self.cfg.num_workers,
410
+ collate_fn=dataset.collate_fn,
411
+ drop_last=True,
412
+ )
413
+ return sample_loader
414
+
415
+ def val_dataloader(
416
+ self, drop_last=True, shuffle=False, input_angle=True
417
+ ) -> DataLoader:
418
+ """Setup and return validation DataLoader
419
+
420
+ Returns:
421
+ DataLoader: validation DataLoader
422
+ """
423
+ dataset = WaymoDataset(self.cfg, "val", input_angle)
424
+ val_loader = DataLoader(
425
+ dataset=dataset,
426
+ batch_size=self.cfg.batch_size,
427
+ shuffle=shuffle,
428
+ num_workers=self.cfg.num_workers,
429
+ collate_fn=dataset.collate_fn,
430
+ drop_last=drop_last,
431
+ )
432
+ torch.cuda.empty_cache()
433
+ return val_loader
434
+
435
+ def train_dataloader(
436
+ self, drop_last=True, shuffle=True, input_angle=True
437
+ ) -> DataLoader:
438
+ """Setup and return training DataLoader
439
+
440
+ Returns:
441
+ DataLoader: training DataLoader
442
+ """
443
+ dataset = WaymoDataset(self.cfg, "train", input_angle)
444
+ train_loader = DataLoader(
445
+ dataset=dataset,
446
+ batch_size=self.cfg.batch_size,
447
+ shuffle=shuffle,
448
+ num_workers=self.cfg.num_workers,
449
+ collate_fn=dataset.collate_fn,
450
+ drop_last=drop_last,
451
+ )
452
+ torch.cuda.empty_cache()
453
+ return train_loader
454
+
455
+ def test_dataloader(self) -> DataLoader:
456
+ """Setup and return test DataLoader
457
+
458
+ Returns:
459
+ DataLoader: test DataLoader
460
+ """
461
+ raise NotImplementedError("The waymo dataloader cannot load test samples yet.")
462
+
463
+ @staticmethod
464
+ def unnormalize_trajectory(
465
+ input: torch.Tensor, offset: torch.Tensor
466
+ ) -> torch.Tensor:
467
+ """Unnormalize trajectory by adding offset to input
468
+
469
+ Args:
470
+ input : (..., (n_sample), num_steps_future, state_dim) tensor of future
471
+ trajectory y
472
+ offset : (..., state_dim) tensor of offset to add to y
473
+
474
+ Returns:
475
+ Unnormalized trajectory that has the same size as input
476
+ """
477
+ assert offset.ndim == 3
478
+ batch_size, num_agents = offset.shape[:2]
479
+ offset_state_dim = offset.shape[-1]
480
+ assert offset_state_dim <= input.shape[-1]
481
+ assert input.shape[0] == batch_size
482
+ assert input.shape[1] == num_agents
483
+ input_copy = input.clone()
484
+
485
+ input_copy[..., :offset_state_dim] = input_copy[
486
+ ..., :offset_state_dim
487
+ ] + offset[..., : input.shape[-1]].reshape(
488
+ [batch_size, num_agents, *([1] * (input.ndim - 3)), offset_state_dim]
489
+ )
490
+ return input_copy
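+
+ # Usage sketch (illustrative, assuming a loaded mmcv Config `cfg`):
+ #   dataloaders = WaymoDataloaders(cfg)
+ #   train_loader = dataloaders.train_dataloader()
+ #   val_loader = dataloaders.val_dataloader(shuffle=False, drop_last=False)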
scripts/eval_scripts/compute_stats.py ADDED
@@ -0,0 +1,523 @@
1
+ import os
2
+ import io
3
+ import pickle
4
+ import sys
5
+
6
+ from functools import partial
7
+ from inspect import signature
8
+ import matplotlib.pyplot as plt
9
+ from tqdm import tqdm
10
+
11
+ from einops import repeat
12
+ import fire
13
+ import numpy as np
14
+ from pytorch_lightning.utilities.seed import seed_everything
15
+ import torch
16
+
17
+ from risk_biased.utils.config_argparse import config_argparse
18
+ from risk_biased.utils.cost import TTCCostTorch, TTCCostParams, get_cost
19
+ from risk_biased.utils.risk import get_risk_estimator
20
+
21
+ from risk_biased.utils.load_model import load_from_config
22
+
23
+
24
+ def to_device(batch, device):
25
+ output = []
26
+ for item in batch:
27
+ output.append(item.to(device))
28
+ return output
29
+
30
+
31
+ class CPU_Unpickler(pickle.Unpickler):
32
+ def find_class(self, module, name):
33
+ if module == "torch.storage" and name == "_load_from_bytes":
34
+ return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
35
+ else:
36
+ return super().find_class(module, name)
37
+
38
+
39
+ def distance(pred, truth):
40
+ """
41
+ pred (Tensor): (..., time, xy)
42
+ truth (Tensor): (..., time, xy)
43
+ Returns the (..., time) Euclidean distance between pred and truth positions.
44
+ """
45
+ return torch.sqrt(torch.sum(torch.square(pred[..., :2] - truth[..., :2]), -1))
46
+
47
+
48
+ def compute_metrics(
49
+ predictor,
50
+ batch,
51
+ cost,
52
+ risk_levels,
53
+ risk_estimator,
54
+ dt,
55
+ unnormalizer,
56
+ n_samples_risk,
57
+ n_samples_stats,
58
+ ):
59
+
60
+ # risk_unbiased
61
+ # risk_biased
62
+ # cost
63
+ # FDE: unbiased, biased(risk_level=[0, 0.3, 0.5, 0.8, 1]) (for all samples so minFDE can be computed later)
64
+ # ADE (for all samples so minADE can be computed later)
65
+
66
+ x, mask_x, y, mask_y, mask_loss, map, mask_map, offset, x_ego, y_ego = batch
67
+ mask_z = mask_x.any(-1)
68
+
69
+ _, z_mean_inference, z_log_std_inference = predictor.model(
70
+ x,
71
+ mask_x,
72
+ map,
73
+ mask_map,
74
+ offset=offset,
75
+ x_ego=x_ego,
76
+ y_ego=y_ego,
77
+ risk_level=None,
78
+ )
79
+
80
+ latent_distribs = {
81
+ "inference": {
82
+ "mean": z_mean_inference[:, 1].detach().cpu(),
83
+ "log_std": z_log_std_inference[:, 1].detach().cpu(),
84
+ }
85
+ }
86
+ inference_distances = []
87
+ cost_list = []
88
+ # Cut the number of samples in packs to avoid out-of-memory problems
89
+ # Compute and store cost for all packs
90
+ for _ in range(n_samples_risk // n_samples_stats):
91
+ z_samples_inference = predictor.model.inference_encoder.sample(
92
+ z_mean_inference,
93
+ z_log_std_inference,
94
+ n_samples=n_samples_stats,
95
+ )
96
+
97
+ y_samples = predictor.model.decode(
98
+ z_samples=z_samples_inference,
99
+ mask_z=mask_z,
100
+ x=x,
101
+ mask_x=mask_x,
102
+ map=map,
103
+ mask_map=mask_map,
104
+ offset=offset,
105
+ )
106
+
107
+ mask_loss_samples = repeat(mask_loss, "b a t -> b a s t", s=n_samples_stats)
108
+ # Computing unbiased cost
109
+ cost_list.append(
110
+ get_cost(
111
+ cost,
112
+ x,
113
+ y_samples,
114
+ offset,
115
+ x_ego,
116
+ y_ego,
117
+ dt,
118
+ unnormalizer,
119
+ mask_loss_samples,
120
+ )[:, 1:2]
121
+ )
122
+ inference_distances.append(distance(y_samples, y.unsqueeze(2))[:, 1:2])
123
+ cost_dic = {}
124
+ cost_dic["inference"] = torch.cat(cost_list, 2).detach().cpu()
125
+ distance_dic = {}
126
+ distance_dic["inference"] = torch.cat(inference_distances, 2).detach().cpu()
127
+
128
+ # Set up the output risk tensor
129
+ risk_dic = {}
130
+
131
+ # Loop on risk_level values to fill the risk estimation for each value and compute stats at each risk level
132
+ for rl in risk_levels:
133
+ risk_level = (
134
+ torch.ones(
135
+ (x.shape[0], x.shape[1]),
136
+ device=x.device,
137
+ )
138
+ * rl
139
+ )
140
+ risk_dic[f"biased_{rl}"] = risk_estimator(
141
+ risk_level[:, 1:2].detach().cpu(), cost_dic["inference"]
142
+ )
143
+
144
+ y_samples_biased, z_mean_biased, z_log_std_biased = predictor.model(
145
+ x,
146
+ mask_x,
147
+ map,
148
+ mask_map,
149
+ offset=offset,
150
+ x_ego=x_ego,
151
+ y_ego=y_ego,
152
+ risk_level=risk_level,
153
+ n_samples=n_samples_stats,
154
+ )
155
+ latent_distribs[f"biased_{rl}"] = {
156
+ "mean": z_mean_biased[:, 1].detach().cpu(),
157
+ "log_std": z_log_std_biased[:, 1].detach().cpu(),
158
+ }
159
+
160
+ distance_dic[f"biased_{rl}"] = (
161
+ distance(y_samples_biased, y.unsqueeze(2))[:, 1].detach().cpu()
162
+ )
163
+ cost_dic[f"biased_{rl}"] = (
164
+ get_cost(
165
+ cost,
166
+ x,
167
+ y_samples_biased,
168
+ offset,
169
+ x_ego,
170
+ y_ego,
171
+ dt,
172
+ unnormalizer,
173
+ mask_loss_samples,
174
+ )[:, 1]
175
+ .detach()
176
+ .cpu()
177
+ )
178
+
179
+ # Return risks for the batch and all risk values
180
+ return {
181
+ "risk": risk_dic,
182
+ "cost": cost_dic,
183
+ "distance": distance_dic,
184
+ "latent_distribs": latent_distribs,
185
+ "mask": mask_loss[:, 1].detach().cpu(),
186
+ }
187
+
188
+
189
+ def cat_metrics_rec(metrics1, metrics2, cat_to):
190
+ for key in metrics1.keys():
191
+ if key not in metrics2.keys():
192
+ raise RuntimeError(
193
+ f"Trying to concatenate objects with different keys: {key} is not in second argument keys."
194
+ )
195
+ elif isinstance(metrics1[key], dict):
196
+ if key not in cat_to.keys():
197
+ cat_to[key] = {}
198
+ cat_metrics_rec(metrics1[key], metrics2[key], cat_to[key])
199
+ elif isinstance(metrics1[key], torch.Tensor):
200
+ cat_to[key] = torch.cat((metrics1[key], metrics2[key]), 0)
201
+
202
+
203
+ def cat_metrics(metrics1, metrics2):
204
+ out = {}
205
+ cat_metrics_rec(metrics1, metrics2, out)
206
+ return out
207
+
208
+
209
+ def masked_mean_std_ste(data, mask):
210
+ mask = mask.view(data.shape)
211
+ norm = mask.sum().clamp_min(1)
212
+ mean = (data * mask).sum() / norm
213
+ std = torch.sqrt(((data - mean) * mask).square().sum() / norm)
214
+ return mean.item(), std.item(), (std / torch.sqrt(norm)).item()
215
+
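+ # Note on masked_mean_std_ste (added for clarity): `ste` is the standard error of the
+ # masked mean, std / sqrt(n_valid), where n_valid is the number of unmasked entries.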
216
+
217
+ def masked_mean_range(data, mask):
218
+ data = data[mask]
219
+ mean = data.mean()
220
+ min = torch.quantile(data, 0.05)
221
+ max = torch.quantile(data, 0.95)
222
+ return mean, min, max
223
+
224
+
225
+ def masked_mean_dim(data, mask, dim):
226
+ norm = mask.sum(dim).clamp_min(1)
227
+ mean = (data * mask).sum(dim) / norm
228
+ return mean
229
+
230
+
231
+ def plot_risk_error(metrics, risk_levels, risk_estimator, max_n_samples, path_save):
232
+ cost_inference = metrics["cost"]["inference"]
233
+ cost_biased_0 = metrics["cost"]["biased_0"]
234
+ mask = metrics["mask"].any(1)
235
+ ones_tensor = torch.ones(mask.shape[0])
236
+ n_samples = np.minimum(cost_biased_0.shape[1], max_n_samples)
237
+
238
+ for rl in risk_levels:
239
+ key = f"biased_{rl}"
240
+ reference_risk = metrics["risk"][key]
241
+ mean_inference_risk_error_per_samples = np.zeros(n_samples - 1)
242
+ min_inference_risk_error_per_samples = np.zeros(n_samples - 1)
243
+ max_inference_risk_error_per_samples = np.zeros(n_samples - 1)
244
+ # mean_biased_0_risk_error_per_samples = np.zeros(n_samples-1)
245
+ # min_biased_0_risk_error_per_samples = np.zeros(n_samples-1)
246
+ # max_biased_0_risk_error_per_samples = np.zeros(n_samples-1)
247
+ mean_biased_risk_error_per_samples = np.zeros(n_samples - 1)
248
+ min_biased_risk_error_per_samples = np.zeros(n_samples - 1)
249
+ max_biased_risk_error_per_samples = np.zeros(n_samples - 1)
250
+ risk_level_tensor = ones_tensor * rl
251
+ for sub_samples in range(1, n_samples):
252
+ perm = torch.randperm(metrics["cost"][key].shape[1])[:sub_samples]
253
+ risk_error_biased = metrics["cost"][key][:, perm].mean(1) - reference_risk
254
+ (
255
+ mean_biased_risk_error_per_samples[sub_samples - 1],
256
+ min_biased_risk_error_per_samples[sub_samples - 1],
257
+ max_biased_risk_error_per_samples[sub_samples - 1],
258
+ ) = masked_mean_range(risk_error_biased, mask)
259
+ risk_error_inference = (
260
+ risk_estimator(risk_level_tensor, cost_inference[:, :, :sub_samples])
261
+ - reference_risk
262
+ )
263
+ (
264
+ mean_inference_risk_error_per_samples[sub_samples - 1],
265
+ min_inference_risk_error_per_samples[sub_samples - 1],
266
+ max_inference_risk_error_per_samples[sub_samples - 1],
267
+ ) = masked_mean_range(risk_error_inference, mask)
268
+ # risk_error_biased_0 = risk_estimator(risk_level_tensor, cost_biased_0[:, :sub_samples]) - reference_risk
269
+ # (mean_biased_0_risk_error_per_samples[sub_samples-1], min_biased_0_risk_error_per_samples[sub_samples-1], max_biased_0_risk_error_per_samples[sub_samples-1]) = masked_mean_range(risk_error_biased_0, mask)
270
+
271
+ plt.plot(
272
+ range(1, n_samples),
273
+ mean_inference_risk_error_per_samples,
274
+ label="Inference",
275
+ )
276
+ plt.fill_between(
277
+ range(1, n_samples),
278
+ min_inference_risk_error_per_samples,
279
+ max_inference_risk_error_per_samples,
280
+ alpha=0.3,
281
+ )
282
+
283
+ # plt.plot(range(1, n_samples), mean_biased_0_risk_error_per_samples, label="Unbiased")
284
+ # plt.fill_between(range(1, n_samples), min_biased_0_risk_error_per_samples, max_biased_0_risk_error_per_samples, alpha=.3)
285
+
286
+ plt.plot(
287
+ range(1, n_samples), mean_biased_risk_error_per_samples, label="Biased"
288
+ )
289
+ plt.fill_between(
290
+ range(1, n_samples),
291
+ min_biased_risk_error_per_samples,
292
+ max_biased_risk_error_per_samples,
293
+ alpha=0.3,
294
+ )
295
+ plt.ylim(
296
+ np.min(min_inference_risk_error_per_samples),
297
+ np.max(max_biased_risk_error_per_samples),
298
+ )
299
+
300
+ plt.hlines(y=0, xmin=0, xmax=n_samples, colors="black", linestyles="--", lw=0.3)
301
+
302
+ plt.xlabel("Number of samples")
303
+ plt.ylabel("Risk estimation error")
304
+ plt.legend()
305
+ plt.title(f"Risk estimation error at risk-level={rl}")
306
+ plt.gcf().set_size_inches(4, 3)
307
+ plt.legend(loc="lower right")
308
+ plt.savefig(fname=os.path.join(path_save, f"risk_level_{rl}.svg"))
309
+ plt.savefig(fname=os.path.join(path_save, f"risk_level_{rl}.png"))
310
+ plt.clf()
311
+ # plt.show()
312
+
313
+
314
+ def compute_stats(metrics, n_samples_mean_cost=4):
315
+ biased_risk_estimate = {}
316
+ for key in metrics["cost"].keys():
317
+ if key == "inference":
318
+ continue
319
+ risk = metrics["risk"][key]
320
+ mean_cost = metrics["cost"][key][:, :n_samples_mean_cost].mean(1)
321
+ risk_error = mean_cost - risk
322
+ biased_risk_estimate[key] = {}
323
+ (
324
+ biased_risk_estimate[key]["mean"],
325
+ biased_risk_estimate[key]["std"],
326
+ biased_risk_estimate[key]["ste"],
327
+ ) = masked_mean_std_ste(risk_error, metrics["mask"].any(1))
328
+
329
+ (
330
+ biased_risk_estimate[key]["mean_abs"],
331
+ biased_risk_estimate[key]["std_abs"],
332
+ biased_risk_estimate[key]["ste_abs"],
333
+ ) = masked_mean_std_ste(risk_error.abs(), metrics["mask"].any(1))
334
+
335
+ risk_stats = {}
336
+ for key in metrics["risk"].keys():
337
+ risk_stats[key] = {}
338
+ (
339
+ risk_stats[key]["mean"],
340
+ risk_stats[key]["std"],
341
+ risk_stats[key]["ste"],
342
+ ) = masked_mean_std_ste(metrics["risk"][key], metrics["mask"].any(1))
343
+
344
+ cost_stats = {}
345
+ for key in metrics["cost"].keys():
346
+ cost_stats[key] = {}
347
+ (
348
+ cost_stats[key]["mean"],
349
+ cost_stats[key]["std"],
350
+ cost_stats[key]["ste"],
351
+ ) = masked_mean_std_ste(
352
+ metrics["cost"][key], metrics["mask"].any(-1, keepdim=True)
353
+ )
354
+
355
+ distance_stats = {}
356
+ for key in metrics["distance"].keys():
357
+ distance_stats[key] = {"FDE": {}, "ADE": {}, "minFDE": {}, "minADE": {}}
358
+ (
359
+ distance_stats[key]["FDE"]["mean"],
360
+ distance_stats[key]["FDE"]["std"],
361
+ distance_stats[key]["FDE"]["ste"],
362
+ ) = masked_mean_std_ste(
363
+ metrics["distance"][key][..., -1], metrics["mask"][:, None, -1]
364
+ )
365
+ mean_dist = masked_mean_dim(
366
+ metrics["distance"][key], metrics["mask"][:, None, :], -1
367
+ )
368
+ (
369
+ distance_stats[key]["ADE"]["mean"],
370
+ distance_stats[key]["ADE"]["std"],
371
+ distance_stats[key]["ADE"]["ste"],
372
+ ) = masked_mean_std_ste(mean_dist, metrics["mask"].any(-1, keepdim=True))
373
+ for i in [6, 16, 32]:
374
+ distance_stats[key]["minFDE"][i] = {}
375
+ min_dist, _ = metrics["distance"][key][:, :i, -1].min(1)
376
+ (
377
+ distance_stats[key]["minFDE"][i]["mean"],
378
+ distance_stats[key]["minFDE"][i]["std"],
379
+ distance_stats[key]["minFDE"][i]["ste"],
380
+ ) = masked_mean_std_ste(min_dist, metrics["mask"][:, -1])
381
+ distance_stats[key]["minADE"][i] = {}
382
+ mean_dist, _ = masked_mean_dim(
383
+ metrics["distance"][key][:, :i], metrics["mask"][:, None, :], -1
384
+ ).min(1)
385
+ (
386
+ distance_stats[key]["minADE"][i]["mean"],
387
+ distance_stats[key]["minADE"][i]["std"],
388
+ distance_stats[key]["minADE"][i]["ste"],
389
+ ) = masked_mean_std_ste(mean_dist, metrics["mask"].any(-1))
390
+ return {
391
+ "risk": risk_stats,
392
+ "biased_risk_estimate": biased_risk_estimate,
393
+ "cost": cost_stats,
394
+ "distance": distance_stats,
395
+ }
396
+
397
+
398
+ def print_stats(stats, n_samples_mean_cost=4):
399
+ slash = "\\"
400
+ brace_open = "{"
401
+ brace_close = "}"
402
+ print("\\begin{tabular}{lccccc}")
403
+ print("\\hline")
404
+ print(
405
+ f"Predictive Model & ${slash}sigma$ & minFDE(16) & FDE (1) & Risk est. error ({n_samples_mean_cost}) & Risk est. $|$error$|$ ({n_samples_mean_cost}) {slash}{slash}"
406
+ )
407
+ print("\\hline")
408
+
409
+ for key in stats["distance"].keys():
410
+ strg = (
411
+ f" ${stats['distance'][key]['minFDE'][16]['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['distance'][key]['minFDE'][16]['ste']:.2f}${brace_close}"
412
+ + f"& ${stats['distance'][key]['FDE']['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['distance'][key]['FDE']['ste']:.2f}${brace_close}"
413
+ )
414
+
415
+ if key == "inference":
416
+ strg = (
417
+ "Unbiased CVAE & "
418
+ + f"{slash}scriptsize{brace_open}NA{brace_close} &"
419
+ + strg
420
+ + f"& {slash}scriptsize{brace_open}NA{brace_close} & {slash}scriptsize{brace_open}NA{brace_close} {slash}{slash}"
421
+ )
422
+ print(strg)
423
+ print("\\hline")
424
+ else:
425
+ strg = (
426
+ "Biased CVAE & "
427
+ + f"{key[7:]} & "
428
+ + strg
429
+ + f"& ${stats['biased_risk_estimate'][key]['mean']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['biased_risk_estimate'][key]['ste']:.2f}${brace_close}"
430
+ + f"& ${stats['biased_risk_estimate'][key]['mean_abs']:.2f}$ {slash}scriptsize{brace_open}${slash}pm {stats['biased_risk_estimate'][key]['ste_abs']:.2f}${brace_close}"
431
+ + f"{slash}{slash}"
432
+ )
433
+ print(strg)
434
+ print("\\hline")
435
+ print("\\end{tabular}")
436
+
437
+
438
+ def main(
439
+ log_path,
440
+ force_recompute,
441
+ n_samples_risk=256,
442
+ n_samples_stats=32,
443
+ n_samples_plot=16,
444
+ args_to_parser=[],
445
+ ):
446
+ # Overwrite sys.argv so it doesn't mess up the parser.
447
+ sys.argv = sys.argv[0:1] + args_to_parser
448
+ working_dir = os.path.dirname(os.path.realpath(__file__))
449
+ config_path = os.path.join(
450
+ working_dir, "..", "..", "risk_biased", "config", "learning_config.py"
451
+ )
452
+ waymo_config_path = os.path.join(
453
+ working_dir, "..", "..", "risk_biased", "config", "waymo_config.py"
454
+ )
455
+ cfg = config_argparse([config_path, waymo_config_path])
456
+
457
+ file_path = os.path.join(log_path, f"metrics_{cfg.load_from}.pickle")
458
+ fig_path = os.path.join(log_path, f"plots_{cfg.load_from}")
459
+ if not os.path.exists(fig_path):
460
+ os.makedirs(fig_path)
461
+
462
+ risk_levels = [0, 0.3, 0.5, 0.8, 0.95, 1]
463
+ cost = TTCCostTorch(TTCCostParams.from_config(cfg))
464
+ risk_estimator = get_risk_estimator(cfg.risk_estimator)
465
+ n_samples_mean_cost = 4
466
+
467
+ if not os.path.exists(file_path) or force_recompute:
468
+ with torch.no_grad():
469
+ if cfg.seed is not None:
470
+ seed_everything(cfg.seed)
471
+
472
+ predictor, dataloaders, cfg = load_from_config(cfg)
473
+ device = torch.device(cfg.gpus[0])
474
+ predictor = predictor.to(device)
475
+
476
+ val_loader = dataloaders.val_dataloader(shuffle=False, drop_last=False)
477
+
478
+ # This loops over batches in the validation dataset
479
+ beg = 0
480
+ metrics_all = None
481
+ for val_batch in tqdm(val_loader):
482
+ end = beg + val_batch[0].shape[0]
483
+ metrics = compute_metrics(
484
+ predictor=predictor,
485
+ batch=to_device(val_batch, device),
486
+ cost=cost,
487
+ risk_levels=risk_levels,
488
+ risk_estimator=risk_estimator,
489
+ dt=cfg.dt,
490
+ unnormalizer=dataloaders.unnormalize_trajectory,
491
+ n_samples_risk=n_samples_risk,
492
+ n_samples_stats=n_samples_stats,
493
+ )
494
+ if metrics_all is None:
495
+ metrics_all = metrics
496
+ else:
497
+ metrics_all = cat_metrics(metrics_all, metrics)
498
+ beg = end
499
+ with open(file_path, "wb") as handle:
500
+ pickle.dump(metrics_all, handle)
501
+ else:
502
+ print(f"Loading pre-computed metrics from {file_path}")
503
+ with open(file_path, "rb") as handle:
504
+ metrics_all = CPU_Unpickler(handle).load()
505
+
506
+ stats = compute_stats(metrics_all, n_samples_mean_cost=n_samples_mean_cost)
507
+ print_stats(stats, n_samples_mean_cost=n_samples_mean_cost)
508
+ plot_risk_error(metrics_all, risk_levels, risk_estimator, n_samples_plot, fig_path)
509
+
510
+
511
+ if __name__ == "__main__":
512
+ # main("./logs/002/", False, 256, 32, 16)
513
+ # Fire turns the main function into a script, then the risk_biased module argparse reads the other arguments.
514
+ # Thus, the way to use it would be:
515
+ # >python compute_stats.py <path to existing log dir> <Force recompute> <n_samples_risk> <n_samples_stats> <n_samples_plot> <other argparse arguments, example --load_from 1uail32>
516
+
517
+ # This is a hack to separate the Fire script args from the argparse arguments
518
+ args_to_parser = sys.argv[len(signature(main).parameters) :]
519
+ partial_main = partial(main, args_to_parser=args_to_parser)
520
+ sys.argv = sys.argv[: len(signature(main).parameters)]
521
+
522
+ # Runs the main as a script
523
+ fire.Fire(partial_main)
scripts/eval_scripts/draw_cost.py ADDED
@@ -0,0 +1,84 @@
1
+ import os
2
+
3
+ import matplotlib.pyplot as plt
4
+ from mmcv import Config
5
+ import numpy as np
6
+ from pytorch_lightning.utilities.seed import seed_everything
7
+
8
+ from risk_biased.scene_dataset.scene import RandomScene, RandomSceneParams
9
+ from risk_biased.scene_dataset.scene_plotter import ScenePlotter
10
+ from risk_biased.utils.cost import (
11
+ DistanceCostNumpy,
12
+ DistanceCostParams,
13
+ TTCCostNumpy,
14
+ TTCCostParams,
15
+ )
16
+
17
+ if __name__ == "__main__":
18
+ working_dir = os.path.dirname(os.path.realpath(os.path.join(__file__, "..")))
19
+ config_path = os.path.join(
20
+ working_dir, "..", "risk_biased", "config", "learning_config.py"
21
+ )
22
+ config = Config.fromfile(config_path)
23
+ if config.seed is not None:
24
+ seed_everything(config.seed)
25
+ ped_speed = 2
26
+ is_torch = False
27
+
28
+ fig, ax = plt.subplots(
29
+ 3, 4, sharex=True, sharey=True, tight_layout=True, subplot_kw={"aspect": 1}
30
+ )
31
+
32
+ scene_params = RandomSceneParams.from_config(config)
33
+ scene_params.batch_size = 1
34
+ for ii in range(9):
35
+ test_scene = RandomScene(
36
+ scene_params,
37
+ is_torch=is_torch,
38
+ )
39
+ dist_cost = DistanceCostNumpy(DistanceCostParams.from_config(config))
40
+ ttc_cost = TTCCostNumpy(TTCCostParams.from_config(config))
41
+
42
+ nx = 1000
43
+ ny = 100
44
+ x, y = np.meshgrid(
45
+ np.linspace(-test_scene.ego_length, test_scene.road_length, nx),
46
+ np.linspace(test_scene.bottom, test_scene.top, ny),
47
+ )
48
+
49
+ i = 2 - (int(ii >= 6) + int(ii >= 3))
50
+ j = ii % 3
51
+ vx = float(ii % 3 - 1)
52
+ vy = float((ii >= 6)) - float(ii <= 2)
53
+ print(f"horizontal velocity {vx}")
54
+ print(f"vertical velocity {vy}")
55
+ norm = np.maximum(np.sqrt(vx * vx + vy * vy), 1)
56
+ vx = vx / norm * np.ones([nx * ny, 1])
57
+ vy = vy / norm * np.ones([nx * ny, 1])
58
+ v_ped = ped_speed * np.stack((vx, vy), -1)
59
+ v_ego = np.array([[[test_scene.ego_ref_speed, 0]]])
60
+
61
+ p_init = np.stack((x, y), -1).reshape((nx * ny, 2))
62
+ p_final = p_init + v_ped[:, 0, :] * test_scene.time_scene
63
+ len_traj = 30
64
+ ped_trajs = np.linspace(p_init, p_final, len_traj, axis=1)
65
+ ego_traj = np.linspace(
66
+ [[0, 0]],
67
+ [test_scene.ego_ref_speed * test_scene.time_scene, 0],
68
+ len_traj,
69
+ axis=1,
70
+ )
71
+
72
+ cost, _ = ttc_cost(ego_traj, ped_trajs, v_ego, v_ped)
73
+
74
+ cost = cost.reshape(ny, nx)
75
+ colorbar = ax[i][j].pcolormesh(x, y, cost, cmap="RdBu")
76
+ plotter = ScenePlotter(test_scene, ax=ax[i][j])
77
+ plotter.plot_road()
78
+
79
+ fig.subplots_adjust(wspace=0.1, hspace=0.1)
80
+ fig.tight_layout()
81
+ fig.colorbar(colorbar, ax=ax.ravel().tolist())
82
+ for a in ax[:, -1]:
83
+ a.remove()
84
+ plt.show()